diff --git a/packages/db/prisma/migrations/20260324182442_support_mcp_clients/migration.sql b/packages/db/prisma/migrations/20260324182442_support_mcp_clients/migration.sql new file mode 100644 index 000000000..3d3d9966f --- /dev/null +++ b/packages/db/prisma/migrations/20260324182442_support_mcp_clients/migration.sql @@ -0,0 +1,41 @@ +-- CreateTable +CREATE TABLE "McpServer" ( + "id" TEXT NOT NULL, + "name" TEXT NOT NULL, + "serverUrl" TEXT NOT NULL, + "clientInfo" TEXT, + "orgId" INTEGER NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "McpServer_pkey" PRIMARY KEY ("id") +); + +-- CreateTable +CREATE TABLE "McpServerCredential" ( + "id" TEXT NOT NULL, + "userId" TEXT NOT NULL, + "serverId" TEXT NOT NULL, + "tokens" TEXT, + "codeVerifier" TEXT, + "state" TEXT, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" TIMESTAMP(3) NOT NULL, + + CONSTRAINT "McpServerCredential_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "McpServer_serverUrl_orgId_key" ON "McpServer"("serverUrl", "orgId"); + +-- CreateIndex +CREATE UNIQUE INDEX "McpServerCredential_userId_serverId_key" ON "McpServerCredential"("userId", "serverId"); + +-- AddForeignKey +ALTER TABLE "McpServer" ADD CONSTRAINT "McpServer_orgId_fkey" FOREIGN KEY ("orgId") REFERENCES "Org"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "McpServerCredential" ADD CONSTRAINT "McpServerCredential_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "McpServerCredential" ADD CONSTRAINT "McpServerCredential_serverId_fkey" FOREIGN KEY ("serverId") REFERENCES "McpServer"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/db/prisma/migrations/20260325184501_add_mcp_server_credential_state_index/migration.sql b/packages/db/prisma/migrations/20260325184501_add_mcp_server_credential_state_index/migration.sql new file mode 100644 index 000000000..d14625836 --- /dev/null +++ b/packages/db/prisma/migrations/20260325184501_add_mcp_server_credential_state_index/migration.sql @@ -0,0 +1,2 @@ +-- CreateIndex +CREATE INDEX "McpServerCredential_state_idx" ON "McpServerCredential"("state"); diff --git a/packages/db/prisma/migrations/20260326230727_/migration.sql b/packages/db/prisma/migrations/20260326230727_/migration.sql new file mode 100644 index 000000000..b17ca3d7e --- /dev/null +++ b/packages/db/prisma/migrations/20260326230727_/migration.sql @@ -0,0 +1,24 @@ +/* + Warnings: + + - You are about to drop the column `name` on the `McpServer` table. All the data in the column will be lost. + +*/ +-- AlterTable +ALTER TABLE "McpServer" DROP COLUMN "name"; + +-- CreateTable +CREATE TABLE "UserMcpServer" ( + "userId" TEXT NOT NULL, + "serverId" TEXT NOT NULL, + "name" TEXT NOT NULL, + "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "UserMcpServer_pkey" PRIMARY KEY ("userId","serverId") +); + +-- AddForeignKey +ALTER TABLE "UserMcpServer" ADD CONSTRAINT "UserMcpServer_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE; + +-- AddForeignKey +ALTER TABLE "UserMcpServer" ADD CONSTRAINT "UserMcpServer_serverId_fkey" FOREIGN KEY ("serverId") REFERENCES "McpServer"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/packages/db/prisma/migrations/20260327233318_add_tokens_expires_at/migration.sql b/packages/db/prisma/migrations/20260327233318_add_tokens_expires_at/migration.sql new file mode 100644 index 000000000..26f316ab1 --- /dev/null +++ b/packages/db/prisma/migrations/20260327233318_add_tokens_expires_at/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "McpServerCredential" ADD COLUMN "tokensExpiresAt" TIMESTAMP(3); diff --git a/packages/db/prisma/schema.prisma b/packages/db/prisma/schema.prisma index 7e1af6be7..26bbdad3e 100644 --- a/packages/db/prisma/schema.prisma +++ b/packages/db/prisma/schema.prisma @@ -292,6 +292,8 @@ model Org { chats Chat[] + mcpServers McpServer[] + license License? /// Set the first time this instance is seen to be on a trial subscription. @@ -409,6 +411,9 @@ model User { /// claim baked into the JWT cookie at mint time. sessionVersion Int @default(0) + mcpServerCredentials McpServerCredential[] + userMcpServers UserMcpServer[] + createdAt DateTime @default(now()) updatedAt DateTime @updatedAt @@ -603,3 +608,72 @@ model OAuthToken { createdAt DateTime @default(now()) lastUsedAt DateTime? } + +/// An external MCP server endpoint, unique per org. +/// Stores the dynamic client registration (client_id/client_secret) once per org. +model McpServer { + id String @id @default(cuid()) + serverUrl String /// MCP server endpoint (e.g., "https://mcp.linear.app/mcp") + + /// Dynamic client registration result (RFC 7591). + /// Encrypted JSON of OAuthClientInformation: { client_id, client_secret, client_id_issued_at, client_secret_expires_at } + /// Null until first user in the org triggers registration. + clientInfo String? + + org Org @relation(fields: [orgId], references: [id], onDelete: Cascade) + orgId Int + + credentials McpServerCredential[] + userMcpServers UserMcpServer[] + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@unique([serverUrl, orgId]) +} + +/// Junction table: a user's personal reference to an MCP server with their chosen display name. +model UserMcpServer { + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + + server McpServer @relation(fields: [serverId], references: [id], onDelete: Cascade) + serverId String + + name String /// User-chosen display name (e.g., "Linear") + + createdAt DateTime @default(now()) + + @@id([userId, serverId]) +} + +/// Per-user OAuth credentials for an external MCP server. +/// Stores tokens (long-lived) and ephemeral auth-flow state separately. +model McpServerCredential { + id String @id @default(cuid()) + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + + server McpServer @relation(fields: [serverId], references: [id], onDelete: Cascade) + serverId String + + /// OAuth tokens (access_token, refresh_token, etc.) — encrypted JSON of OAuthTokens. + tokens String? + + /// Absolute expiry time of the access token, computed at issuance from expires_in. + /// Null when no tokens are stored or the provider did not include expires_in. + tokensExpiresAt DateTime? + + /// PKCE code verifier — ephemeral, only used between redirect and callback. + codeVerifier String? + + /// OAuth state parameter — ephemeral, for CSRF protection during auth flow. + state String? + + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + @@unique([userId, serverId]) + @@index([state]) +} diff --git a/packages/shared/src/env.server.ts b/packages/shared/src/env.server.ts index 036655018..21d5e2b37 100644 --- a/packages/shared/src/env.server.ts +++ b/packages/shared/src/env.server.ts @@ -278,6 +278,7 @@ const options = { */ SOURCEBOT_CHAT_MODEL_TEMPERATURE: numberSchema.optional(), SOURCEBOT_CHAT_MAX_STEP_COUNT: numberSchema.default(100), + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: numberSchema.default(60000), DEBUG_WRITE_CHAT_MESSAGES_TO_FILE: booleanSchema.default('false'), DEBUG_ENABLE_REACT_SCAN: booleanSchema.default('false'), diff --git a/packages/web/package.json b/packages/web/package.json index 8825a21c5..9ca986d6a 100644 --- a/packages/web/package.json +++ b/packages/web/package.json @@ -20,6 +20,7 @@ "@ai-sdk/deepseek": "^2.0.29", "@ai-sdk/google": "^3.0.64", "@ai-sdk/google-vertex": "^4.0.111", + "@ai-sdk/mcp": "^2.0.0-beta.11", "@ai-sdk/mistral": "^3.0.30", "@ai-sdk/openai": "^3.0.53", "@ai-sdk/openai-compatible": "^2.0.41", @@ -196,7 +197,7 @@ "use-stick-to-bottom": "^1.1.3", "usehooks-ts": "^3.1.0", "vscode-icons-js": "^11.6.1", - "zod": "^3.25.74", + "zod": "^3.25.76", "zod-to-json-schema": "^3.24.5" }, "devDependencies": { diff --git a/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx b/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx index 43ccb1a87..6bc248ce8 100644 --- a/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx +++ b/packages/web/src/app/(app)/askgh/[owner]/[repo]/components/landingPage.tsx @@ -69,7 +69,7 @@ export const LandingPage = ({
{ - createNewChatThread(children, selectedSearchScopes); + createNewChatThread(children, selectedSearchScopes, []); }} className="min-h-[50px]" isRedirecting={isLoading} diff --git a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx index 574001e5f..3cf15df48 100644 --- a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx +++ b/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx @@ -40,11 +40,13 @@ export const ChatThreadPanel = ({ localStorage.removeItem(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY); }, []); - // Use the last user's last message to determine what repos and contexts we should select by default. + // Use the last user message to determine what repos, contexts, and MCP state we should select by default. const lastUserMessage = messages.findLast((message) => message.role === "user"); const defaultSelectedSearchScopes = lastUserMessage?.metadata?.selectedSearchScopes ?? []; + const defaultDisabledMcpServerIds = lastUserMessage?.metadata?.disabledMcpServerIds ?? []; const [selectedSearchScopes, setSelectedSearchScopes] = useState(defaultSelectedSearchScopes); - + const [disabledMcpServerIds, setDisabledMcpServerIds] = useState(defaultDisabledMcpServerIds); + useEffect(() => { if (!chatState) { return; @@ -53,6 +55,7 @@ export const ChatThreadPanel = ({ try { setInputMessage(chatState.inputMessage); setSelectedSearchScopes(chatState.selectedSearchScopes); + setDisabledMcpServerIds(chatState.disabledMcpServerIds); } catch { console.error('Invalid chat state in session storage'); } finally { @@ -72,6 +75,8 @@ export const ChatThreadPanel = ({ searchContexts={searchContexts} selectedSearchScopes={selectedSearchScopes} onSelectedSearchScopesChange={setSelectedSearchScopes} + disabledMcpServerIds={disabledMcpServerIds} + onDisabledMcpServerIdsChange={setDisabledMcpServerIds} isOwner={isOwner} isAuthenticated={isAuthenticated} chatName={chatName} diff --git a/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx b/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx index 9d6b92381..99c2a5fb7 100644 --- a/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx +++ b/packages/web/src/app/(app)/chat/components/landingPageChatBox.tsx @@ -8,7 +8,7 @@ import { useCreateNewChatThread } from "@/features/chat/useCreateNewChatThread"; import { RepositoryQuery, SearchContextQuery } from "@/lib/types"; import { useState } from "react"; import { useLocalStorage } from "usehooks-ts"; -import { SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY } from "@/features/chat/constants"; +import { DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY } from "@/features/chat/constants"; import { SearchModeSelector } from "../../components/searchModeSelector"; import { NotConfiguredErrorBanner } from "@/features/chat/components/notConfiguredErrorBanner"; import { LoginModal } from "@/app/components/loginModal"; @@ -28,6 +28,7 @@ export const LandingPageChatBox = ({ }: LandingPageChatBox) => { const { createNewChatThread, isLoading, loginWall } = useCreateNewChatThread({ isAuthenticated }); const [selectedSearchScopes, setSelectedSearchScopes] = useLocalStorage(SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY, [], { initializeWithValue: false }); + const [disabledMcpServerIds, setDisabledMcpServerIds] = useLocalStorage(DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY, [], { initializeWithValue: false }); const [isContextSelectorOpen, setIsContextSelectorOpen] = useState(false); const isChatBoxDisabled = languageModels.length === 0; @@ -36,7 +37,7 @@ export const LandingPageChatBox = ({
{ - createNewChatThread(children, selectedSearchScopes); + createNewChatThread(children, selectedSearchScopes, disabledMcpServerIds); }} className="min-h-[50px]" isRedirecting={isLoading} @@ -56,6 +57,8 @@ export const LandingPageChatBox = ({ onSelectedSearchScopesChange={setSelectedSearchScopes} isContextSelectorOpen={isContextSelectorOpen} onContextSelectorOpenChanged={setIsContextSelectorOpen} + disabledMcpServerIds={disabledMcpServerIds} + onDisabledMcpServerIdsChange={setDisabledMcpServerIds} /> icon: "link" as const, } ] : []), + ...(await hasEntitlement("oauth") ? [ + { + title: "MCP Servers", + href: `/settings/mcpServers`, + } + ] : []), ], }, ]; diff --git a/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.tsx b/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.tsx new file mode 100644 index 000000000..1263a7c02 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpServers/mcpServersPage.tsx @@ -0,0 +1,285 @@ +'use client'; + +import { useEffect, useRef, useState } from "react"; +import { useToast } from "@/components/hooks/use-toast"; +import { isServiceError } from "@/lib/utils"; +import { createMcpServer, deleteMcpServer } from "@/ee/features/mcp/actions"; +import { getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { ConnectMcpButton } from "@/ee/features/mcp/components/connectMcpButton"; +import { Button } from "@/components/ui/button"; +import { + Dialog, DialogContent, DialogFooter, DialogHeader, DialogTitle, DialogTrigger, +} from "@/components/ui/dialog"; +import { + AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, + AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger, +} from "@/components/ui/alert-dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Card, CardContent, CardDescription, CardFooter, CardHeader, CardTitle } from "@/components/ui/card"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { Loader2, Plus, Server, Trash2 } from "lucide-react"; +import { Skeleton } from "@/components/ui/skeleton"; + +function clearCallbackParams() { + const url = new URL(window.location.href); + url.searchParams.delete('status'); + url.searchParams.delete('server'); + url.searchParams.delete('message'); + window.history.replaceState({}, '', url.toString()); +} + +interface McpServersPageProps { + callbackStatus?: string; + callbackServer?: string; + callbackMessage?: string; +} + +export function McpServersPage({ callbackStatus, callbackServer, callbackMessage }: McpServersPageProps) { + const { toast } = useToast(); + const didHandleCallbackRef = useRef(false); + + useEffect(() => { + if (didHandleCallbackRef.current) { + return; + } + if (callbackStatus === 'connected') { + didHandleCallbackRef.current = true; + toast({ description: `Successfully connected${callbackServer ? ` to ${callbackServer}` : ''}.` }); + clearCallbackParams(); + } else if (callbackStatus === 'error') { + didHandleCallbackRef.current = true; + toast({ title: "Connection failed", description: callbackMessage ?? 'Failed to connect MCP server.', variant: "destructive" }); + clearCallbackParams(); + } + }, [callbackStatus, callbackServer, callbackMessage, toast]); + + const queryClient = useQueryClient(); + + const { data: servers = [], isLoading, isError } = useQuery({ + queryKey: ['mcpServersWithStatus'], + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load MCP servers"); + } + return result; + }, + }); + + // Create dialog state + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + const [newServerName, setNewServerName] = useState(""); + const [newServerUrl, setNewServerUrl] = useState(""); + const [isCreating, setIsCreating] = useState(false); + + // Delete state + const [deletingServerId, setDeletingServerId] = useState(null); + + const handleCreate = async () => { + if (!newServerUrl.trim()) { + toast({ title: "Error", description: "Server URL is required", variant: "destructive" }); + return; + } + + setIsCreating(true); + try { + const result = await createMcpServer(newServerName.trim(), newServerUrl.trim()); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to add MCP server: ${result.message}`, variant: "destructive" }); + return; + } + await queryClient.invalidateQueries({ queryKey: ['mcpServersWithStatus'] }); + handleCloseCreateDialog(); + } catch (e) { + toast({ title: "Error", description: `Failed to add MCP server: ${e}`, variant: "destructive" }); + } finally { + setIsCreating(false); + } + }; + + const handleCloseCreateDialog = () => { + setIsCreateDialogOpen(false); + setNewServerName(""); + setNewServerUrl(""); + }; + + const handleDelete = async (serverId: string) => { + setDeletingServerId(serverId); + try { + const result = await deleteMcpServer(serverId); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to delete: ${result.message}`, variant: "destructive" }); + return; + } + await queryClient.invalidateQueries({ queryKey: ['mcpServersWithStatus'] }); + } catch (e) { + toast({ title: "Error", description: `Failed to delete MCP server: ${e}`, variant: "destructive" }); + } finally { + setDeletingServerId(null); + } + }; + + if (isError) { + return
Error loading MCP servers
; + } + + return ( +
+ {/* Header + Add button */} +
+
+

MCP Servers

+

+ Connect external MCP servers to use with Ask Sourcebot. +

+
+ + + + + + + + Add MCP Server + +
+
+ + setNewServerName(e.target.value)} + placeholder="e.g. Linear" + /> +
+
+ + setNewServerUrl(e.target.value)} + placeholder="https://mcp.linear.app/mcp" + /> +
+
+ + + + +
+
+
+ + {/* Server list */} + {isLoading ? ( +
+ {Array.from({ length: 2 }).map((_, i) => ( + + + + + + + + + + ))} +
+ ) : servers.length === 0 ? ( + + +
+ +
+

No MCP servers yet

+

+ Click "Add MCP Server" above to connect an external MCP server. +

+
+
+ ) : ( +
+ {servers.map((server) => ( + + +
+
+ +
+ {server.name || server.serverUrl} + {server.serverUrl} +
+
+ + + + + + + Delete MCP Server + + Are you sure you want to remove {server.name || server.serverUrl}? This will remove the server and your credentials from your list. + + + + Cancel + handleDelete(server.id)} + disabled={deletingServerId === server.id} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + {deletingServerId === server.id ? "Deleting..." : "Delete"} + + + + +
+
+ + {server.isConnected && ( +
+ + Connected +
+ )} + {server.isAuthExpired && ( +
+ + Authorization expired +
+ )} + {!server.isConnected && !server.isAuthExpired && ( +
+ + Not connected +
+ )} +
+ + + +
+ ))} +
+ )} +
+ ); +} \ No newline at end of file diff --git a/packages/web/src/app/(app)/settings/mcpServers/page.tsx b/packages/web/src/app/(app)/settings/mcpServers/page.tsx new file mode 100644 index 000000000..edfd780d6 --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcpServers/page.tsx @@ -0,0 +1,14 @@ +import { McpServersPage } from "./mcpServersPage"; + +interface PageProps { + searchParams: Promise<{ + status?: string; + server?: string; + message?: string; + }>; +} + +export default async function Page({ searchParams }: PageProps) { + const { status, server, message } = await searchParams; + return ; +} diff --git a/packages/web/src/app/api/(client)/client.ts b/packages/web/src/app/api/(client)/client.ts index 22c689278..5d11cfef4 100644 --- a/packages/web/src/app/api/(client)/client.ts +++ b/packages/web/src/app/api/(client)/client.ts @@ -29,6 +29,8 @@ import type { SearchChatShareableMembersQueryParams, SearchChatShareableMembersResponse, } from "../(server)/ee/chat/[chatId]/searchMembers/route"; +import { ConnectMcpResponse } from "../(server)/ee/askmcp/connect/types"; +import type { GetMcpServersResponse } from "../(server)/ee/askmcp/servers/route"; export const search = async (body: SearchRequest): Promise => { const result = await fetch("/api/search", { @@ -214,4 +216,33 @@ export const listChats = async (queryParams: ListChatsQueryParams): Promise response.json()); return result as ListChatsResponse | ServiceError; -} \ No newline at end of file +} + +export const connectMcpToAsk = async (body: { serverId: string }): Promise => { + const result = await fetch('/api/ee/askmcp/connect', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + body: JSON.stringify(body), + }).then(response => response.json()); + + if (isServiceError(result)) { + return result; + } + + return result as ConnectMcpResponse; +} + +export const getMcpServersWithStatus = async (): Promise => { + const result = await fetch('/api/ee/askmcp/servers', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + }).then(response => response.json()); + + return result as GetMcpServersResponse | ServiceError; +} diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/chat/route.ts index 4c0b12819..84b3de016 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/chat/route.ts @@ -33,7 +33,7 @@ export const POST = apiHandler(async (req: NextRequest) => { return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); } - const { messages, id, selectedSearchScopes, languageModel: _languageModel } = parsed.data; + const { messages, id, selectedSearchScopes, disabledMcpServerIds, languageModel: _languageModel } = parsed.data; // @note: a bit of type massaging is required here since the // zod schema does not enum on `model` or `provider`. // @see: chat/types.ts @@ -108,10 +108,13 @@ export const POST = apiHandler(async (req: NextRequest) => { selectedSearchScopes, }, selectedRepos: expandedRepos, + disabledMcpServerIds, model, modelName: languageModelConfig.displayName ?? languageModelConfig.model, modelProviderOptions: providerOptions, modelTemperature: temperature, + userId: user?.id, + orgId: org.id, onFinish: async ({ messages }) => { await updateChatMessages({ chatId: id, messages, prisma }); }, diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts new file mode 100644 index 000000000..bd340b9a0 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts @@ -0,0 +1,122 @@ +import { auth as mcpAuth } from '@ai-sdk/mcp'; +import { apiHandler } from '@/lib/apiHandler'; +import { env, createLogger } from '@sourcebot/shared'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +// Note: We use the raw (unscoped) prisma client here because this route handles OAuth +// redirect callbacks from external providers, so it can't go through withAuth. Session +// identity is verified via NextAuth's auth() instead, and all queries filter by userId. +import { __unsafePrisma as prisma } from '@/prisma'; +import { auth } from '@/auth'; +import { NextRequest, NextResponse } from 'next/server'; + +const logger = createLogger('mcp-oauth-callback'); + +export const GET = apiHandler(async (request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const session = await auth(); + if (!session?.user?.id) { + return Response.json( + { error: 'unauthorized', error_description: 'You must be logged in.' }, + { status: 401 } + ); + } + + const { searchParams } = request.nextUrl; + const oauthError = searchParams.get('error'); + const code = searchParams.get('code'); + const state = searchParams.get('state'); + + // Handle OAuth errors (e.g., user cancelled the authorization flow). + if (oauthError) { + const settingsUrl = new URL(`/settings/mcpServers`, env.AUTH_URL); + settingsUrl.searchParams.set('status', 'error'); + const errorDescription = searchParams.get('error_description') ?? 'Authorization was cancelled or denied.'; + settingsUrl.searchParams.set('message', errorDescription); + return NextResponse.redirect(settingsUrl); + } + + if (!code || !state) { + return Response.json( + { error: 'invalid_request', error_description: 'Missing required parameters: code, state.' }, + { status: 400 } + ); + } + + const credential = await prisma.mcpServerCredential.findFirst({ + where: { + state, + userId: session.user.id, + }, + include: { + server: { + include: { + userMcpServers: { + where: { userId: session.user.id }, + take: 1, + }, + }, + }, + }, + }); + + if (!credential) { + return Response.json( + { error: 'invalid_state', error_description: 'No pending authorization found for this state.' }, + { status: 400 } + ); + } + + const orgMembership = await prisma.userToOrg.findUnique({ + where: { + orgId_userId: { + orgId: credential.server.orgId, + userId: session.user.id, + }, + }, + }); + + if (!orgMembership) { + return Response.json( + { error: 'forbidden', error_description: 'You do not have access to this MCP server.' }, + { status: 403 } + ); + } + + const provider = new PrismaOAuthClientProvider( + credential.serverId, + session.user.id, + `${env.AUTH_URL}/api/ee/askmcp/callback`, + ); + + const result = await mcpAuth(provider, { + serverUrl: new URL(credential.server.serverUrl), + authorizationCode: code, + callbackState: state, + }); + + // Always clear ephemeral PKCE/state regardless of outcome to prevent replay. + await provider.invalidateCredentials('verifier'); + + const settingsUrl = new URL(`/settings/mcpServers`, env.AUTH_URL); + + if (result === 'AUTHORIZED') { + const displayName = credential.server.userMcpServers[0]?.name ?? credential.server.serverUrl; + logger.info(`Successfully authorized MCP server ${displayName} for user ${session.user.id}.`); + settingsUrl.searchParams.set('status', 'connected'); + settingsUrl.searchParams.set('server', displayName); + return NextResponse.redirect(settingsUrl); + } + + // If auth() didn't return AUTHORIZED, something went wrong + settingsUrl.searchParams.set('status', 'error'); + settingsUrl.searchParams.set('message', 'Token exchange failed'); + return NextResponse.redirect(settingsUrl); +}); \ No newline at end of file diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts new file mode 100644 index 000000000..8d0ff1b0e --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts @@ -0,0 +1,77 @@ +import { auth as mcpAuth } from '@ai-sdk/mcp'; +import { apiHandler } from '@/lib/apiHandler'; +import { withAuth } from '@/middleware/withAuth'; +import { sew } from '@/middleware/sew'; +import { isServiceError } from '@/lib/utils'; +import { serviceErrorResponse, notFound, requestBodySchemaValidationError } from '@/lib/serviceError'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { NextRequest } from 'next/server'; +import { z } from 'zod'; +import { hasEntitlement } from '@/lib/entitlements'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import { ConnectMcpResponse } from "@/app/api/(server)/ee/askmcp/connect/types"; +import { env } from "@sourcebot/shared"; + +const bodySchema = z.object({ serverId: z.string() }); + +export const POST = apiHandler(async (request: NextRequest) => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const body = await request.json(); + const parsed = bodySchema.safeParse(body); + if (!parsed.success) { + return serviceErrorResponse(requestBodySchemaValidationError(parsed.error)); + } + + const result = await sew(() => + withAuth(async ({ user, org, prisma }) => { + const mcpServer = await prisma.mcpServer.findUnique({ + where: { id: parsed.data.serverId, orgId: org.id }, + }); + if (!mcpServer) { + return notFound('MCP server not found'); + } + + // Verify the user has added this server to their list. + const userServer = await prisma.userMcpServer.findUnique({ + where: { + userId_serverId: { + userId: user.id, + serverId: mcpServer.id, + }, + }, + }); + if (!userServer) { + return notFound('MCP server not found'); + } + + const provider = new PrismaOAuthClientProvider( + mcpServer.id, + user.id, + `${env.AUTH_URL}/api/ee/askmcp/callback`, + ); + + const result = await mcpAuth(provider, { + serverUrl: new URL(mcpServer.serverUrl), + }); + + if (result === 'AUTHORIZED') { + // Already has valid tokens (e.g., refreshed) + return { authorizationUrl: null } satisfies ConnectMcpResponse; + } + + return { authorizationUrl: provider.authorizationUrl! } satisfies ConnectMcpResponse; + }) + ); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); \ No newline at end of file diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts new file mode 100644 index 000000000..80281ae17 --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/types.ts @@ -0,0 +1,4 @@ +export interface ConnectMcpResponse { + /** The external OAuth authorization URL the browser should navigate to. Null if already authorized. */ + authorizationUrl: string | null; +} \ No newline at end of file diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts new file mode 100644 index 000000000..8c922faba --- /dev/null +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts @@ -0,0 +1,91 @@ +import { apiHandler } from '@/lib/apiHandler'; +import { serviceErrorResponse } from '@/lib/serviceError'; +import { isServiceError } from '@/lib/utils'; +import { withAuth } from '@/middleware/withAuth'; +import { hasEntitlement } from '@/lib/entitlements'; +import { decryptOAuthToken } from '@sourcebot/shared'; +import { sanitizeMcpServerName } from '@/ee/features/mcp/utils'; +import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; +import type { OAuthTokens } from '@ai-sdk/mcp'; + +export interface McpServerWithStatus { + id: string; + name: string; + serverUrl: string; + sanitizedName: string; + faviconUrl: string; + isConnected: boolean; + isAuthExpired: boolean; +} + +export type GetMcpServersResponse = McpServerWithStatus[]; + +export const GET = apiHandler(async () => { + if (!(await hasEntitlement('oauth'))) { + return Response.json( + { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, + { status: 403 } + ); + } + + const result = await withAuth(async ({ user, prisma }) => { + const userServers = await prisma.userMcpServer.findMany({ + where: { userId: user.id }, + orderBy: { createdAt: 'desc' }, + include: { + server: { + include: { + credentials: { + where: { userId: user.id }, + take: 1, + }, + }, + }, + }, + }); + + return userServers.map((us): McpServerWithStatus => { + const credential = us.server.credentials[0] ?? null; + const sanitizedName = sanitizeMcpServerName(us.name); + const origin = new URL(us.server.serverUrl).origin; + const faviconUrl = `https://www.google.com/s2/favicons?domain=${origin}&sz=32`; + + let isConnected = false; + let isAuthExpired = false; + + if (credential?.tokens) { + try { + const decrypted = decryptOAuthToken(credential.tokens); + if (decrypted) { + const tokens: OAuthTokens = JSON.parse(decrypted); + if (tokens.refresh_token || !credential.tokensExpiresAt) { + isConnected = true; + } else if (new Date() > credential.tokensExpiresAt) { + isAuthExpired = true; + } else { + isConnected = true; + } + } + } catch { + // treat as not connected if decryption fails + } + } + + return { + id: us.server.id, + name: us.name, + serverUrl: us.server.serverUrl, + sanitizedName, + faviconUrl, + isConnected, + isAuthExpired, + }; + }); + }); + + if (isServiceError(result)) { + return serviceErrorResponse(result); + } + + return Response.json(result); +}); \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/actions.ts b/packages/web/src/ee/features/mcp/actions.ts new file mode 100644 index 000000000..f1b827bb8 --- /dev/null +++ b/packages/web/src/ee/features/mcp/actions.ts @@ -0,0 +1,136 @@ +'use server'; + +import { sew } from '@/middleware/sew'; +import { ErrorCode } from '@/lib/errorCodes'; +import { ServiceError } from '@/lib/serviceError'; +import { withAuth } from '@/middleware/withAuth'; +import { StatusCodes } from 'http-status-codes'; +import { z } from 'zod'; +import { sanitizeMcpServerName } from './utils'; + +export const createMcpServer = async (name: string, serverUrl: string) => sew(() => + withAuth(async ({ org, user, prisma }) => { + const urlResult = z.string().url().safeParse(serverUrl); + if (!urlResult.success || !serverUrl.startsWith('https://')) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Invalid server URL. Must be a valid HTTPS URL.', + } satisfies ServiceError; + } + + const sanitizedName = sanitizeMcpServerName(name); + const alphanumericCount = (sanitizedName.match(/[a-z0-9]/g) ?? []).length; + if (alphanumericCount < 3) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: 'Server name must contain at least 3 alphanumeric characters.', + } satisfies ServiceError; + } + + // Upsert the McpServer record — reuse if the endpoint already exists for this org. + const mcpServer = await prisma.mcpServer.upsert({ + where: { + serverUrl_orgId: { + serverUrl, + orgId: org.id, + }, + }, + update: {}, + create: { + serverUrl, + orgId: org.id, + }, + }); + + // Check if this user already has this server in their list. + const existingUserServer = await prisma.userMcpServer.findUnique({ + where: { + userId_serverId: { + userId: user.id, + serverId: mcpServer.id, + }, + }, + }); + + if (existingUserServer) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: `You have already added an MCP server with URL "${serverUrl}".`, + } satisfies ServiceError; + } + + // Ensure the sanitized name is unique within the user's own servers to prevent + // tool-name collisions (e.g. "My Server" and "My-Server" both become "my_server"). + const userServers = await prisma.userMcpServer.findMany({ + where: { userId: user.id }, + select: { name: true }, + }); + const nameCollision = userServers.some( + (s) => sanitizeMcpServerName(s.name) === sanitizedName + ); + if (nameCollision) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.MCP_SERVER_ALREADY_EXISTS, + message: `You already have an MCP server with a similar name. Please choose a more distinct name.`, + } satisfies ServiceError; + } + + await prisma.userMcpServer.create({ + data: { + userId: user.id, + serverId: mcpServer.id, + name, + }, + }); + + return { + id: mcpServer.id, + name, + serverUrl: mcpServer.serverUrl, + }; + })); + +export const deleteMcpServer = async (serverId: string) => sew(() => + withAuth(async ({ user, prisma }) => { + const userServer = await prisma.userMcpServer.findUnique({ + where: { + userId_serverId: { + userId: user.id, + serverId, + }, + }, + }); + + if (!userServer) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.MCP_SERVER_NOT_FOUND, + message: 'MCP server not found', + } satisfies ServiceError; + } + + // Delete the user's reference and their credentials. The McpServer row stays + // because other users may reference the same endpoint. + await prisma.$transaction([ + prisma.mcpServerCredential.deleteMany({ + where: { + userId: user.id, + serverId, + }, + }), + prisma.userMcpServer.delete({ + where: { + userId_serverId: { + userId: user.id, + serverId, + }, + }, + }), + ]); + + return { success: true }; + })); \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/components/connectMcpButton.tsx b/packages/web/src/ee/features/mcp/components/connectMcpButton.tsx new file mode 100644 index 000000000..d2b00c516 --- /dev/null +++ b/packages/web/src/ee/features/mcp/components/connectMcpButton.tsx @@ -0,0 +1,58 @@ +'use client'; + +import { useState } from 'react'; +import { LoadingButton } from '@/components/ui/loading-button'; +import { useToast } from '@/components/hooks/use-toast'; +import { isServiceError } from '@/lib/utils'; +import { connectMcpToAsk } from '@/app/api/(client)/client'; +import { ExternalLink } from 'lucide-react'; + +interface ConnectMcpButtonProps { + serverId: string; + isConnected?: boolean; + isAuthExpired?: boolean; +} + +export function ConnectMcpButton({ serverId, isConnected, isAuthExpired }: ConnectMcpButtonProps) { + const [loading, setLoading] = useState(false); + const { toast } = useToast(); + + const buttonLabel = isConnected || isAuthExpired ? "Reconnect" : "Connect MCP Server"; + const buttonVariant = isConnected ? "outline" as const : undefined; + + const handleConnect = async () => { + setLoading(true); + const result = await connectMcpToAsk({ serverId }); + + if (isServiceError(result)) { + toast({ + description: `Failed to connect MCP server. ${result.message}`, + }); + setLoading(false); + return; + } + + if (result.authorizationUrl) { + // OAuth flow — redirect to the authorization URL + window.location.href = result.authorizationUrl; + // Keep loading=true while redirecting (same pattern as ManageSubscriptionButton) + } else { + // Already authorized + toast({ + description: 'MCP server is already connected.', + }); + setLoading(false); + } + }; + + return ( + + {buttonLabel} + + + ); +} \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/components/mcpFavicon.tsx b/packages/web/src/ee/features/mcp/components/mcpFavicon.tsx new file mode 100644 index 000000000..2220fc516 --- /dev/null +++ b/packages/web/src/ee/features/mcp/components/mcpFavicon.tsx @@ -0,0 +1,24 @@ +'use client'; + +import { Plug } from "lucide-react"; +import { useState } from "react"; + +interface McpFaviconProps { + faviconUrl: string | undefined; + className?: string; +} + +export const McpFavicon = ({ faviconUrl, className = "w-4 h-4" }: McpFaviconProps) => { + const [failed, setFailed] = useState(false); + if (faviconUrl && !failed) { + return ( + setFailed(true)} + className={`${className} flex-shrink-0`} + alt="" + /> + ); + } + return ; +}; \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/mcpClientFactory.test.ts b/packages/web/src/ee/features/mcp/mcpClientFactory.test.ts new file mode 100644 index 000000000..69eefd6d1 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpClientFactory.test.ts @@ -0,0 +1,132 @@ +import { expect, test, describe, vi } from 'vitest'; +import { prisma } from '@/__mocks__/prisma'; +import type { OAuthTokens } from '@ai-sdk/mcp'; + +// --- Mocks --- + +vi.mock('@/prisma', async () => { + const actual = await vi.importActual('@/__mocks__/prisma'); + return { ...actual }; +}); + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + env: { AUTH_URL: 'http://localhost:3000' }, + decryptOAuthToken: vi.fn((s: string) => s), +})); + +vi.mock('server-only', () => ({ default: vi.fn() })); + +vi.mock('@/features/mcp/prismaOAuthClientProvider', () => ({ + PrismaOAuthClientProvider: vi.fn(), +})); + +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => ({ + StreamableHTTPClientTransport: vi.fn(), +})); + +// Import after mocks are set up +const { isTokenExpiredWithNoRefresh, getConnectedMcpClients } = await import('./mcpClientFactory'); + +// --- Helpers --- + +const PAST = new Date('2020-01-01'); +const FUTURE = new Date('2099-01-01'); + +const TOKEN_NO_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer' }; +const TOKEN_WITH_REFRESH: OAuthTokens = { access_token: 'tok', token_type: 'Bearer', refresh_token: 'ref' }; + +function makeCredential(overrides: { + tokens?: OAuthTokens; + tokensExpiresAt?: Date | null; + orgId?: number; + hasUserServer?: boolean; +}) { + return { + serverId: 'srv-1', + userId: 'user-1', + tokens: JSON.stringify(overrides.tokens ?? TOKEN_NO_REFRESH), + tokensExpiresAt: overrides.tokensExpiresAt ?? null, + codeVerifier: null, + state: null, + server: { + orgId: overrides.orgId ?? 1, + serverUrl: 'https://example.com/mcp', + userMcpServers: overrides.hasUserServer === false ? [] : [{ name: 'MyServer' }], + }, + }; +} + +// --- isTokenExpiredWithNoRefresh --- + +describe('isTokenExpiredWithNoRefresh', () => { + test('returns true when access token is expired and no refresh token', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, PAST)).toBe(true); + }); + + test('returns false when refresh_token is present even if access token is expired', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_WITH_REFRESH, PAST)).toBe(false); + }); + + test('returns false when tokensExpiresAt is null', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, null)).toBe(false); + }); + + test('returns false when access token has not yet expired', () => { + expect(isTokenExpiredWithNoRefresh(TOKEN_NO_REFRESH, FUTURE)).toBe(false); + }); +}); + +// --- getConnectedMcpClients --- + +describe('getConnectedMcpClients', () => { + test('skips server when access token expired and no refresh token', async () => { + prisma.mcpServerCredential.findMany.mockResolvedValue([ + makeCredential({ tokens: TOKEN_NO_REFRESH, tokensExpiresAt: PAST }), + ] as never); + + const result = await getConnectedMcpClients('user-1', 1); + expect(result).toHaveLength(0); + }); + + test('includes server when refresh_token present even if access token expired', async () => { + prisma.mcpServerCredential.findMany.mockResolvedValue([ + makeCredential({ tokens: TOKEN_WITH_REFRESH, tokensExpiresAt: PAST }), + ] as never); + + const result = await getConnectedMcpClients('user-1', 1); + expect(result).toHaveLength(1); + }); + + test('includes server when tokensExpiresAt is null', async () => { + prisma.mcpServerCredential.findMany.mockResolvedValue([ + makeCredential({ tokensExpiresAt: null }), + ] as never); + + const result = await getConnectedMcpClients('user-1', 1); + expect(result).toHaveLength(1); + }); + + test('skips server belonging to a different org', async () => { + prisma.mcpServerCredential.findMany.mockResolvedValue([ + makeCredential({ orgId: 999 }), + ] as never); + + const result = await getConnectedMcpClients('user-1', 1); + expect(result).toHaveLength(0); + }); + + test('skips server the user has removed from their list', async () => { + prisma.mcpServerCredential.findMany.mockResolvedValue([ + makeCredential({ hasUserServer: false }), + ] as never); + + const result = await getConnectedMcpClients('user-1', 1); + expect(result).toHaveLength(0); + }); +}); \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/mcpClientFactory.ts b/packages/web/src/ee/features/mcp/mcpClientFactory.ts new file mode 100644 index 000000000..47f7ee809 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpClientFactory.ts @@ -0,0 +1,105 @@ +import { __unsafePrisma } from '@/prisma'; +import { createLogger, env, decryptOAuthToken } from '@sourcebot/shared'; +import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import type { OAuthTokens } from '@ai-sdk/mcp'; + +const logger = createLogger('mcp-client-factory'); + +export interface McpToolSet { + serverId: string; + serverName: string; + serverUrl: string; + transport: StreamableHTTPClientTransport; +} + +/** + * Returns true if the access token is definitely expired and there is no refresh token to fall back on. + */ +export function isTokenExpiredWithNoRefresh(tokens: OAuthTokens, tokensExpiresAt: Date | null): boolean { + if (tokens.refresh_token) { + return false; + } + if (!tokensExpiresAt) { + return false; + } + return new Date() > tokensExpiresAt; +} + +/** + * Creates authenticated transports for all external MCP servers the user has valid credentials for. + * Skips servers with clearly expired tokens and no refresh token. + * Does NOT connect — connection is deferred to createMCPClient. + */ +export async function getConnectedMcpClients(userId: string, orgId: number): Promise { + const credentials = await __unsafePrisma.mcpServerCredential.findMany({ + where: { + userId, + tokens: { not: null }, + }, + include: { + server: { + include: { + userMcpServers: { + where: { userId }, + take: 1, + }, + }, + }, + }, + }); + + const clients: McpToolSet[] = []; + + for (const credential of credentials) { + // Skip servers that don't belong to the current org. + if (credential.server.orgId !== orgId) { + continue; + } + + const userServer = credential.server.userMcpServers[0]; + // Skip if the user has removed this server from their list. + if (!userServer) { + continue; + } + + const serverName = userServer.name; + + try { + const decrypted = decryptOAuthToken(credential.tokens); + if (!decrypted) { + logger.warn(`Could not decrypt tokens for MCP server ${serverName}, skipping.`); + continue; + } + + const tokens: OAuthTokens = JSON.parse(decrypted); + + if (isTokenExpiredWithNoRefresh(tokens, credential.tokensExpiresAt)) { + logger.warn(`Access token for MCP server ${serverName} is expired and has no refresh token. User ${userId} needs to re-authorize.`); + continue; + } + + const provider = new PrismaOAuthClientProvider( + credential.serverId, + userId, + `${env.AUTH_URL}/api/ee/askmcp/callback`, + ); + + const transport = new StreamableHTTPClientTransport( + new URL(credential.server.serverUrl), + { authProvider: provider }, + ); + + clients.push({ + serverId: credential.serverId, + serverName, + serverUrl: credential.server.serverUrl, + transport, + }); + } catch (error) { + logger.error(`Failed to connect to MCP server ${serverName}:`, error); + } + } + + return clients; +} \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/mcpToolRegistry.test.ts b/packages/web/src/ee/features/mcp/mcpToolRegistry.test.ts new file mode 100644 index 000000000..20918f066 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolRegistry.test.ts @@ -0,0 +1,185 @@ +import { expect, test, describe } from 'vitest'; +import { buildMcpToolRegistry, searchMcpTools, McpToolRegistryEntry } from './mcpToolRegistry'; + +// Helper to create a mock tool record matching the MCPClient['tools'] return type. +function createToolRecord(tools: Record) { + const record: Record = {}; + for (const [name, tool] of Object.entries(tools)) { + record[name] = { + description: tool.description, + execute: tool.execute ?? (() => {}), + inputSchema: {}, + }; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return record as any; +} + +describe('buildMcpToolRegistry', () => { + test('extracts serverName from namespaced tool name', () => { + const tools = createToolRecord({ + 'mcp_linear__list_issues': { description: 'List issues' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry).toEqual([ + { name: 'mcp_linear__list_issues', description: 'List issues', serverName: 'linear' }, + ]); + }); + + test('handles underscores in server name', () => { + const tools = createToolRecord({ + 'mcp_my_server__get_data': { description: 'Get data' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].serverName).toBe('my_server'); + }); + + test('defaults missing description to empty string', () => { + const tools = createToolRecord({ + 'mcp_linear__list_issues': { description: undefined }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].description).toBe(''); + }); + + test('non-matching tool name yields empty serverName', () => { + const tools = createToolRecord({ + 'some_random_tool': { description: 'A tool' }, + }); + + const registry = buildMcpToolRegistry(tools); + + expect(registry[0].serverName).toBe(''); + }); + + test('empty tools record returns empty array', () => { + const registry = buildMcpToolRegistry(createToolRecord({})); + + expect(registry).toEqual([]); + }); +}); + +describe('searchMcpTools', () => { + // Shared registry for most tests. + const registry: McpToolRegistryEntry[] = [ + { name: 'mcp_linear__list_issues', description: 'List all issues in a project', serverName: 'linear' }, + { name: 'mcp_linear__create_issue', description: 'Create a new issue', serverName: 'linear' }, + { name: 'mcp_linear__update_issue', description: 'Update an existing issue', serverName: 'linear' }, + { name: 'mcp_github__search_repos', description: 'Search repositories on GitHub', serverName: 'github' }, + { name: 'mcp_pg__run_query', description: 'Run a database query', serverName: 'pg' }, + { name: 'mcp_slack__send_message', description: 'Send a message to a Slack channel', serverName: 'slack' }, + { name: 'mcp_jira__create_ticket', description: 'Create a new Jira ticket', serverName: 'jira' }, + ]; + + test('exact name match returns single result', () => { + const results = searchMcpTools('mcp_linear__list_issues', registry); + + expect(results).toEqual([ + { name: 'mcp_linear__list_issues', description: 'List all issues in a project', serverName: 'linear' }, + ]); + }); + + test('token matching on tool name', () => { + const results = searchMcpTools('list issues', registry); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].name).toBe('mcp_linear__list_issues'); + }); + + test('synonym expansion: "find" matches tools with "list"', () => { + const results = searchMcpTools('find issues', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + expect(names).toContain('mcp_linear__list_issues'); + }); + + test('synonym expansion: "add" matches tools with "create"', () => { + const results = searchMcpTools('add ticket', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + expect(names).toContain('mcp_jira__create_ticket'); + }); + + test('reverse expansion: canonical "list" expands to synonyms', () => { + // "list" is canonical and expands to "find", "get", "fetch", "search", etc. + const results = searchMcpTools('list repos', registry); + + expect(results.length).toBeGreaterThan(0); + const names = results.map(r => r.name); + // "search_repos" should match because "list" expands to "search" + expect(names).toContain('mcp_github__search_repos'); + }); + + test('higher-scoring entries come first', () => { + // "create issue" should score higher for create_issue than for list_issues + const results = searchMcpTools('create issue', registry); + + expect(results.length).toBeGreaterThan(1); + // The first result should be the one that matches both tokens + expect(results[0].name).toBe('mcp_linear__create_issue'); + }); + + test('topK limits results', () => { + const results = searchMcpTools('issue', registry, 2); + + expect(results.length).toBeLessThanOrEqual(2); + }); + + test('default topK is 5', () => { + // All 7 entries match "mcp" as a substring, but we need tokens > 2 chars + // Use a query that matches many entries + const largeRegistry: McpToolRegistryEntry[] = Array.from({ length: 10 }, (_, i) => ({ + name: `mcp_server__tool_${i}`, + description: `Tool number ${i} for testing`, + serverName: 'server', + })); + + const results = searchMcpTools('tool testing', largeRegistry); + + expect(results.length).toBeLessThanOrEqual(5); + }); + + test('short/empty query fallback returns first topK entries', () => { + // "do it" — all tokens are <= 2 chars after filtering + const results = searchMcpTools('do it', registry); + + expect(results).toEqual(registry.slice(0, 5)); + }); + + test('empty string query fallback returns first topK entries', () => { + const results = searchMcpTools('', registry); + + expect(results).toEqual(registry.slice(0, 5)); + }); + + test('returns empty array when no tokens match', () => { + const results = searchMcpTools('xyznonexistent', registry); + + expect(results).toEqual([]); + }); + + test('search matches in description, not just name', () => { + const results = searchMcpTools('database', registry); + + expect(results.length).toBeGreaterThan(0); + expect(results[0].name).toBe('mcp_pg__run_query'); + }); + + test('tokens shorter than 3 chars are filtered out', () => { + // "do a list" → only "list" survives (length > 2) + const results = searchMcpTools('do a list', registry); + + expect(results.length).toBeGreaterThan(0); + // Should still find results via the "list" token + const names = results.map(r => r.name); + expect(names).toContain('mcp_linear__list_issues'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/mcpToolRegistry.ts b/packages/web/src/ee/features/mcp/mcpToolRegistry.ts new file mode 100644 index 000000000..431710e9e --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolRegistry.ts @@ -0,0 +1,99 @@ +import type { MCPClient } from '@ai-sdk/mcp'; + +export interface McpToolRegistryEntry { + name: string; + description: string; + serverName: string; +} + +type McpToolRecord = Awaited>; + +// Synonym map for common action words. Expands query tokens so that e.g. +// "find tickets" matches a tool named "list_issues". +// Module-level constant — built once at server startup, never re-created. +const SYNONYM_MAP: Record = { + list: ['find', 'get', 'fetch', 'retrieve', 'search', 'show', 'query', 'read'], + create: ['make', 'add', 'post', 'open', 'new', 'submit', 'write'], + update: ['edit', 'modify', 'change', 'patch', 'set'], + delete: ['remove', 'destroy', 'archive', 'close'], + send: ['post', 'publish', 'notify', 'message'], + issue: ['ticket', 'bug', 'task', 'item', 'work'], + comment: ['note', 'reply', 'respond'], + user: ['member', 'person', 'assignee'], + project: ['repo', 'repository', 'workspace'], +}; + +// Reverse lookup: synonym → canonical token. Built once from SYNONYM_MAP. +const REVERSE_SYNONYMS: Record = {}; +for (const [canonical, synonyms] of Object.entries(SYNONYM_MAP)) { + for (const synonym of synonyms) { + REVERSE_SYNONYMS[synonym] = canonical; + } +} + +function expandTokens(tokens: string[]): string[] { + const expanded = new Set(tokens); + for (const token of tokens) { + const canonical = REVERSE_SYNONYMS[token]; + if (canonical) { + expanded.add(canonical); + } + const synonyms = SYNONYM_MAP[token]; + if (synonyms) { + for (const s of synonyms) { + expanded.add(s); + } + } + } + return Array.from(expanded); +} + +export function buildMcpToolRegistry(tools: McpToolRecord): McpToolRegistryEntry[] { + return Object.entries(tools).map(([name, tool]) => { + const match = name.match(/^mcp_(.+?)__/); + const serverName = match ? match[1] : ''; + return { + name, + description: tool.description ?? '', + serverName, + }; + }); +} + +export function searchMcpTools( + query: string, + registry: McpToolRegistryEntry[], + topK = 5, +): McpToolRegistryEntry[] { + // Fast path: if the query is an exact tool name, return it directly. + const exactMatch = registry.find(e => e.name === query); + if (exactMatch) { + return [exactMatch]; + } + + const rawTokens = query + .toLowerCase() + .split(/\W+/) + .filter(t => t.length > 2); + + // If no meaningful tokens remain (e.g. query is "do it" — all tokens <= 2 chars), + // fall back to returning the first topK tools rather than returning nothing. + // We could potentially return nothing or return another tool that will help search better + // in the future. + if (rawTokens.length === 0) { + return registry.slice(0, topK); + } + + const tokens = expandTokens(rawTokens); + + return registry + .map(entry => { + const haystack = `${entry.name} ${entry.description}`.toLowerCase(); + const score = tokens.filter(t => haystack.includes(t)).length; + return { entry, score }; + }) + .filter(({ score }) => score > 0) + .sort((a, b) => b.score - a.score) + .slice(0, topK) + .map(({ entry }) => entry); +} \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/mcpToolSets.test.ts b/packages/web/src/ee/features/mcp/mcpToolSets.test.ts new file mode 100644 index 000000000..d49f56986 --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolSets.test.ts @@ -0,0 +1,284 @@ +import { expect, test, describe, vi, beforeEach } from 'vitest'; +import type { McpToolSet } from './mcpClientFactory'; + +// --- Mocks --- + +const mockCreateMCPClient = vi.fn(); + +vi.mock('@ai-sdk/mcp', () => ({ + createMCPClient: (...args: unknown[]) => mockCreateMCPClient(...args), +})); + +vi.mock('@sourcebot/shared', () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), + env: { + SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS: 5000, + }, +})); + +vi.mock('ai', () => ({ + jsonSchema: vi.fn((schema: unknown, opts: unknown) => ({ schema, ...(opts as object) })), +})); + +// --- Helpers --- + +interface MockToolDef { + name: string; + description?: string; + inputSchema?: Record; + annotations?: Record; +} + +function createMockMcpClient(toolDefs: MockToolDef[]) { + const toolRecord: Record; description: string | undefined; inputSchema: unknown }> = {}; + for (const def of toolDefs) { + toolRecord[def.name] = { + execute: vi.fn().mockResolvedValue({ content: [{ type: 'text', text: 'result' }] }), + description: def.description, + inputSchema: def.inputSchema ?? {}, + }; + } + + return { + listTools: vi.fn().mockResolvedValue({ tools: toolDefs }), + toolsFromDefinitions: vi.fn().mockReturnValue(toolRecord), + close: vi.fn().mockResolvedValue(undefined), + tools: vi.fn().mockResolvedValue(toolRecord), + }; +} + +function createMockClient(overrides: Partial & { serverName: string }): McpToolSet { + return { + serverId: 'server-id', + serverUrl: `https://${overrides.serverName.toLowerCase()}.example.com/mcp`, + transport: {} as McpToolSet['transport'], + ...overrides, + }; +} + +// --- Tests --- + +// Import after mocks are set up +const { getMcpTools } = await import('./mcpToolSets'); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe('getMcpTools', () => { + test('single server with single tool produces correctly namespaced key', async () => { + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + expect(Object.keys(result.tools)).toEqual(['mcp_linear__list_issues']); + expect(result.failedServers).toEqual([]); + }); + + test('multiple servers produce tools with distinct prefixes', async () => { + const linearClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + const githubClient = createMockMcpClient([ + { name: 'search_repos', description: 'Search repos' }, + ]); + + mockCreateMCPClient + .mockResolvedValueOnce(linearClient) + .mockResolvedValueOnce(githubClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + createMockClient({ serverName: 'GitHub' }), + ]); + + const toolNames = Object.keys(result.tools); + expect(toolNames).toContain('mcp_linear__list_issues'); + expect(toolNames).toContain('mcp_github__search_repos'); + }); + + test('read-only tool does NOT get needsApproval', async () => { + const mockClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues', annotations: { readOnlyHint: true } }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__list_issues']; + expect(tool).toBeDefined(); + expect('needsApproval' in tool).toBe(false); + }); + + test('non-read-only tool gets needsApproval: true', async () => { + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + expect(tool).toBeDefined(); + expect(tool).toHaveProperty('needsApproval', true); + }); + + test('failed server connection adds to failedServers array', async () => { + mockCreateMCPClient.mockRejectedValue(new Error('Connection refused')); + + const result = await getMcpTools([ + createMockClient({ serverName: 'BrokenServer' }), + ]); + + expect(result.failedServers).toEqual(['BrokenServer']); + expect(Object.keys(result.tools)).toEqual([]); + }); + + test('failed server does not prevent other servers from working', async () => { + const goodClient = createMockMcpClient([ + { name: 'list_issues', description: 'List issues' }, + ]); + + mockCreateMCPClient + .mockRejectedValueOnce(new Error('Connection refused')) + .mockResolvedValueOnce(goodClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'BrokenServer' }), + createMockClient({ serverName: 'Linear' }), + ]); + + expect(result.failedServers).toEqual(['BrokenServer']); + expect(Object.keys(result.tools)).toEqual(['mcp_linear__list_issues']); + }); + + test('generates favicon URL from server URL origin', async () => { + const mockClient = createMockMcpClient([ + { name: 'tool', description: 'A tool' }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear', serverUrl: 'https://api.linear.app/mcp' }), + ]); + + expect(result.serverFaviconUrls['linear']).toBe( + 'https://www.google.com/s2/favicons?domain=https://api.linear.app&sz=32' + ); + }); + + test('cleanup function calls close on all clients', async () => { + const client1 = createMockMcpClient([{ name: 'tool1', description: 'Tool 1' }]); + const client2 = createMockMcpClient([{ name: 'tool2', description: 'Tool 2' }]); + + mockCreateMCPClient + .mockResolvedValueOnce(client1) + .mockResolvedValueOnce(client2); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Server1' }), + createMockClient({ serverName: 'Server2' }), + ]); + + await result.cleanup(); + + expect(client1.close).toHaveBeenCalledOnce(); + expect(client2.close).toHaveBeenCalledOnce(); + }); + + test('cleanup handles errors in close gracefully', async () => { + const client1 = createMockMcpClient([{ name: 'tool1', description: 'Tool 1' }]); + const client2 = createMockMcpClient([{ name: 'tool2', description: 'Tool 2' }]); + client1.close.mockRejectedValue(new Error('Close failed')); + + mockCreateMCPClient + .mockResolvedValueOnce(client1) + .mockResolvedValueOnce(client2); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Server1' }), + createMockClient({ serverName: 'Server2' }), + ]); + + // Should not throw + await expect(result.cleanup()).resolves.toBeUndefined(); + expect(client2.close).toHaveBeenCalledOnce(); + }); + + test('empty clients array returns empty result', async () => { + const result = await getMcpTools([]); + + expect(result.tools).toEqual({}); + expect(result.failedServers).toEqual([]); + expect(result.serverFaviconUrls).toEqual({}); + expect(typeof result.cleanup).toBe('function'); + }); + + test('tool schema validation rejects invalid input', async () => { + const mockClient = createMockMcpClient([ + { + name: 'create_issue', + description: 'Create issue', + inputSchema: { + type: 'object', + properties: { title: { type: 'string' } }, + }, + }, + ]); + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + // The inputSchema should have a validate function from our jsonSchema mock + const schema = tool.inputSchema as { validate?: (value: unknown) => Promise<{ success: boolean; error?: Error }> }; + expect(schema.validate).toBeDefined(); + + if (schema.validate) { + // Valid input + const validResult = await schema.validate({ title: 'My Issue' }); + expect(validResult.success).toBe(true); + + // Invalid input (extra property not allowed because additionalProperties: false) + const invalidResult = await schema.validate({ title: 'My Issue', bogus: 'field' }); + expect(invalidResult.success).toBe(false); + } + }); + + test('tool execute wrapper propagates non-timeout errors', async () => { + const originalError = new Error('External API failed'); + const mockClient = createMockMcpClient([ + { name: 'create_issue', description: 'Create issue' }, + ]); + // Override the execute to reject + const toolRecord = mockClient.toolsFromDefinitions(); + toolRecord['create_issue'].execute.mockRejectedValue(originalError); + + mockCreateMCPClient.mockResolvedValue(mockClient); + + const result = await getMcpTools([ + createMockClient({ serverName: 'Linear' }), + ]); + + const tool = result.tools['mcp_linear__create_issue']; + await expect( + tool.execute({}, { messages: [], toolCallId: 'test' }) + ).rejects.toThrow('External API failed'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/mcpToolSets.ts b/packages/web/src/ee/features/mcp/mcpToolSets.ts new file mode 100644 index 000000000..91a235b8b --- /dev/null +++ b/packages/web/src/ee/features/mcp/mcpToolSets.ts @@ -0,0 +1,149 @@ +import { createMCPClient, type MCPClient } from '@ai-sdk/mcp'; +import { McpToolSet } from './mcpClientFactory'; +import { createLogger, env } from '@sourcebot/shared'; +import { sanitizeMcpServerName } from './utils'; +import Ajv from 'ajv'; +import { jsonSchema, ToolExecutionOptions } from 'ai'; +import type { JSONSchema7, JSONSchema7Definition } from 'json-schema'; + +const logger = createLogger('mcp-tool-sets'); +const ajv = new Ajv({ allErrors: true, strict: false }); + +class McpToolTimeoutError extends Error { + constructor(toolName: string, timeoutMs: number) { + super(`MCP tool "${toolName}" timed out after ${timeoutMs}ms`); + this.name = 'McpToolTimeoutError'; + } +} + +export interface McpToolsResult { + tools: Record>[string]>; + failedServers: string[]; + serverFaviconUrls: Record; + cleanup: () => Promise; +} + +/** + * Creates MCPClients from authenticated transports, retrieves their tools, + * and returns a namespaced tool record + cleanup function. + */ +export async function getMcpTools(clients: McpToolSet[]): Promise { + const allTools: McpToolsResult['tools'] = {}; + const failedServers: string[] = []; + const serverFaviconUrls: Record = {}; + const mcpClients: MCPClient[] = []; + + const connectionTimeoutMs = env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + + for (const { serverName, serverUrl, transport } of clients) { + try { + const mcpClient = await Promise.race([ + createMCPClient({ transport }), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Connection to MCP server "${serverName}" timed out after ${connectionTimeoutMs}ms`)), connectionTimeoutMs) + ), + ]); + mcpClients.push(mcpClient); + + const toolDefinitions = await Promise.race([ + mcpClient.listTools(), + new Promise((_, reject) => + setTimeout(() => reject(new Error(`Listing tools from MCP server "${serverName}" timed out after ${connectionTimeoutMs}ms`)), connectionTimeoutMs) + ), + ]); + const tools = mcpClient.toolsFromDefinitions(toolDefinitions); + const sanitizedName = sanitizeMcpServerName(serverName); + const prefix = `mcp_${sanitizedName}`; + + for (const [toolName, tool] of Object.entries(tools)) { + const def = toolDefinitions.tools.find(t => t.name === toolName); + const isReadOnly = (def?.annotations as Record | undefined)?.readOnlyHint === true; + + // The @ai-sdk/mcp library sets additionalProperties: false in the JSON schema + // sent to the model, but does NOT provide a validate function — so the AI SDK + // skips server-side validation entirely. We compile the schema with ajv to + // enforce parameter names at runtime, which allows experimental_repairToolCall + // to fire on InvalidToolInputError. + const rawSchema = def?.inputSchema ?? { type: 'object', properties: {} }; + const schema = { + ...rawSchema, + type: 'object' as const, + properties: (rawSchema.properties ?? {}) as Record, + additionalProperties: false, + } satisfies JSONSchema7; + const validate = ajv.compile(schema); + const validProperties = Object.keys(schema.properties); + const validatedInputSchema = jsonSchema(schema, { + validate: async (value: unknown) => { + if (validate(value)) { + return { success: true as const, value }; + } + return { + success: false as const, + error: new Error( + `${ajv.errorsText(validate.errors)}. The valid parameter names for this tool are: [${validProperties.join(', ')}]` + ), + }; + }, + }); + + const originalExecute = tool.execute; + const qualifiedName = `${prefix}__${toolName}`; + const timeoutMs = env.SOURCEBOT_MCP_TOOL_CALL_TIMEOUT_MS; + + const executeWithTimeout = (async (input: unknown, options: ToolExecutionOptions) => { + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const combinedSignal = options.abortSignal + ? AbortSignal.any([options.abortSignal, timeoutSignal]) + : timeoutSignal; + + try { + return await originalExecute(input, { + ...options, + abortSignal: combinedSignal, + }); + } catch (error) { + if (timeoutSignal.aborted) { + logger.warn(`MCP tool "${qualifiedName}" timed out after ${timeoutMs}ms`); + throw new McpToolTimeoutError(qualifiedName, timeoutMs); + } + throw error; + } + }) as typeof originalExecute; + + allTools[qualifiedName] = { + ...tool, + execute: executeWithTimeout, + // The @ai-sdk/mcp package bundles its own copy of @ai-sdk/provider-utils, + // so its Schema isn't structurally identical to the workspace copy. + // The runtime shape is the same; cast through `any` to bridge the duplicate + // type identity (the two FlexibleSchema types differ only by their internal + // schemaSymbol brand). + // eslint-disable-next-line @typescript-eslint/no-explicit-any + inputSchema: validatedInputSchema as any, + ...(isReadOnly ? {} : { needsApproval: true }), + }; + } + + const origin = new URL(serverUrl).origin; + serverFaviconUrls[sanitizedName] = `https://www.google.com/s2/favicons?domain=${origin}&sz=32`; + } catch (error) { + logger.error(`Failed to get tools from MCP server ${serverName}:`, error); + failedServers.push(serverName); + } + } + + const cleanup = async () => { + await Promise.allSettled( + mcpClients.map(async (client) => { + try { + await client.close(); + } catch (error) { + logger.error('Error closing MCP client:', error); + } + }) + ); + }; + + return { tools: allTools, failedServers, serverFaviconUrls, cleanup }; +} \ No newline at end of file diff --git a/packages/web/src/ee/features/mcp/utils.test.ts b/packages/web/src/ee/features/mcp/utils.test.ts new file mode 100644 index 000000000..c4a63ffc3 --- /dev/null +++ b/packages/web/src/ee/features/mcp/utils.test.ts @@ -0,0 +1,36 @@ +import { expect, test, describe } from 'vitest'; +import { sanitizeMcpServerName } from './utils'; + +describe('sanitizeMcpServerName', () => { + test('lowercases ASCII letters', () => { + expect(sanitizeMcpServerName('MyServer')).toBe('myserver'); + }); + + test('replaces special characters with underscores', () => { + expect(sanitizeMcpServerName('My Server!')).toBe('my_server_'); + }); + + test('preserves digits', () => { + expect(sanitizeMcpServerName('server123')).toBe('server123'); + }); + + test('replaces spaces and hyphens', () => { + expect(sanitizeMcpServerName('my-cool server')).toBe('my_cool_server'); + }); + + test('handles empty string', () => { + expect(sanitizeMcpServerName('')).toBe(''); + }); + + test('replaces unicode characters with underscores', () => { + expect(sanitizeMcpServerName('Ñoño')).toBe('_o_o'); + }); + + test('replaces all special characters', () => { + expect(sanitizeMcpServerName('@#$%')).toBe('____'); + }); + + test('returns already sanitized name unchanged', () => { + expect(sanitizeMcpServerName('linear')).toBe('linear'); + }); +}); diff --git a/packages/web/src/ee/features/mcp/utils.ts b/packages/web/src/ee/features/mcp/utils.ts new file mode 100644 index 000000000..3a0176dba --- /dev/null +++ b/packages/web/src/ee/features/mcp/utils.ts @@ -0,0 +1,11 @@ +/** + * Sanitizes an MCP server name into a lowercase alphanumeric string suitable + * for use as a tool-name prefix (e.g. "My Server!" → "my_server_"). + * + * This is used to namespace MCP tools (mcp_{sanitizedName}__{toolName}) and + * to key favicon maps. Must be kept consistent everywhere — collisions on + * this value are prevented at server-creation time. + */ +export function sanitizeMcpServerName(name: string): string { + return name.toLowerCase().replace(/[^a-z0-9]/g, '_'); +} \ No newline at end of file diff --git a/packages/web/src/features/chat/agent.ts b/packages/web/src/features/chat/agent.ts index 0efb706fc..412c98c28 100644 --- a/packages/web/src/features/chat/agent.ts +++ b/packages/web/src/features/chat/agent.ts @@ -6,17 +6,26 @@ import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; import { ProviderOptions } from "@ai-sdk/provider-utils"; import { createLogger, env } from "@sourcebot/shared"; import { + convertToModelMessages, createUIMessageStream, JSONValue, LanguageModel, ModelMessage, StopCondition, streamText, StreamTextResult, UIMessageStreamOnFinishCallback, UIMessageStreamOptions, - UIMessageStreamWriter + UIMessageStreamWriter, + tool, + Tool, + NoSuchToolError, } from "ai"; +import { z } from "zod"; import { randomUUID } from "crypto"; import _dedent from "dedent"; import { ANSWER_TAG, FILE_REFERENCE_PREFIX } from "./constants"; import { Source } from "./types"; import { addLineNumbers, fileReferenceToString } from "./utils"; import { createTools } from "./tools"; +import { getConnectedMcpClients } from "@/ee/features/mcp/mcpClientFactory"; +import { getMcpTools, McpToolsResult } from "@/ee/features/mcp/mcpToolSets"; +import { buildMcpToolRegistry, McpToolRegistryEntry, searchMcpTools } from "@/ee/features/mcp/mcpToolRegistry"; +import { hasEntitlement } from '@/lib/entitlements'; const dedent = _dedent.withOptions({ alignValues: true }); @@ -36,6 +45,9 @@ interface CreateMessageStreamResponseProps { chatId: string; messages: SBChatMessage[]; selectedRepos: string[]; + // When undefined, MCP tools are disabled entirely (e.g. programmatic callers like askCodebase). + // When an array, MCP tools are enabled for all servers not in the list. + disabledMcpServerIds?: string[]; model: AISDKLanguageModelV3; modelName: string; onFinish: UIMessageStreamOnFinishCallback; @@ -43,6 +55,8 @@ interface CreateMessageStreamResponseProps { modelProviderOptions?: Record>; modelTemperature?: number; metadata?: Partial; + userId?: string; + orgId?: number; } export const createMessageStream = async ({ @@ -50,12 +64,15 @@ export const createMessageStream = async ({ messages, metadata, selectedRepos, + disabledMcpServerIds, model, modelName, modelProviderOptions, modelTemperature, onFinish, onError, + userId, + orgId, }: CreateMessageStreamResponseProps) => { const latestMessage = messages[messages.length - 1]; const sources = latestMessage.parts @@ -66,7 +83,7 @@ export const createMessageStream = async ({ // Extract user messages and assistant answers. // We will use this as the context we carry between messages. - const messageHistory = + let messageHistory: ModelMessage[] = messages.map((message): ModelMessage | undefined => { if (message.role === 'user') { return { @@ -86,6 +103,28 @@ export const createMessageStream = async ({ } }).filter(message => message !== undefined); + // When the last assistant turn has approval responses (from the tool approval flow), + // the turn is incomplete — it has no answer text, only a pending tool call that was + // approved. We need to preserve the full tool call + approval so streamText can + // execute the approved tool and continue. + const lastMsg = messages[messages.length - 1]; + const hasApprovalResponses = lastMsg?.role === 'assistant' && + lastMsg.parts.some(p => p.type === 'dynamic-tool' && p.state === 'approval-responded'); + + // When continuing after tool approval, capture the prior turn's metadata + // so we can aggregate token counts and response times across phases. + const priorMetadata = hasApprovalResponses + ? (lastMsg.metadata as SBChatMessageMetadata | undefined) + : undefined; + + if (hasApprovalResponses) { + const fullLastTurn = await convertToModelMessages( + [lastMsg], + { ignoreIncompleteToolCalls: true } + ); + messageHistory = [...messageHistory, ...fullLastTurn]; + } + const stream = createUIMessageStream({ execute: async ({ writer }) => { writer.write({ @@ -101,17 +140,33 @@ export const createMessageStream = async ({ inputMessages: messageHistory, inputSources: sources, selectedRepos, + disabledMcpServerIds, onWriteSource: (source) => { writer.write({ type: 'data-source', data: source, }); }, + onMcpServerDiscovered: (sanitizedName, faviconUrl) => { + writer.write({ + type: 'data-mcp-server', + data: { sanitizedName, faviconUrl }, + }); + }, + onMcpServerFailed: (serverName) => { + writer.write({ + type: 'data-mcp-failed-server', + data: { serverName }, + }); + }, traceId, chatId, + userId, + orgId, }); await mergeStreamAsync(researchStream, writer, { + originalMessages: messages, sendReasoning: true, sendStart: false, sendFinish: false, @@ -122,10 +177,10 @@ export const createMessageStream = async ({ writer.write({ type: 'message-metadata', messageMetadata: { - totalTokens: totalUsage.totalTokens, - totalInputTokens: totalUsage.inputTokens, - totalOutputTokens: totalUsage.outputTokens, - totalResponseTimeMs: new Date().getTime() - startTime.getTime(), + totalTokens: (priorMetadata?.totalTokens ?? 0) + (totalUsage.totalTokens ?? 0), + totalInputTokens: (priorMetadata?.totalInputTokens ?? 0) + (totalUsage.inputTokens ?? 0), + totalOutputTokens: (priorMetadata?.totalOutputTokens ?? 0) + (totalUsage.outputTokens ?? 0), + totalResponseTimeMs: (priorMetadata?.totalResponseTimeMs ?? 0) + (new Date().getTime() - startTime.getTime()), modelName, traceId, ...metadata, @@ -149,11 +204,16 @@ interface AgentOptions { providerOptions?: ProviderOptions; temperature?: number; selectedRepos: string[]; + disabledMcpServerIds?: string[]; inputMessages: ModelMessage[]; inputSources: Source[]; onWriteSource: (source: Source) => void; + onMcpServerDiscovered: (sanitizedName: string, faviconUrl: string) => void; + onMcpServerFailed: (serverName: string) => void; traceId: string; chatId: string; + userId?: string; + orgId?: number; } const createAgentStream = async ({ @@ -163,9 +223,14 @@ const createAgentStream = async ({ inputMessages, inputSources, selectedRepos, + disabledMcpServerIds, onWriteSource, + onMcpServerDiscovered, + onMcpServerFailed, traceId, chatId, + userId, + orgId, }: AgentOptions) => { // For every file source, resolve the source code so that we can include it in the system prompt. const fileSources = inputSources.filter((source) => source.type === 'file'); @@ -192,48 +257,162 @@ const createAgentStream = async ({ })) ).filter((source) => source !== undefined); + let mcpToolSetsObj: McpToolsResult = { tools: {}, failedServers: [], serverFaviconUrls: {}, cleanup: async () => {} }; + if (userId && orgId && await hasEntitlement('oauth') && disabledMcpServerIds !== undefined) { + try { + const allMcpClients = await getConnectedMcpClients(userId, orgId); + const mcpClients = allMcpClients.filter((c) => !disabledMcpServerIds.includes(c.serverId)); + mcpToolSetsObj = await getMcpTools(mcpClients); + + for (const [sanitizedName, faviconUrl] of Object.entries(mcpToolSetsObj.serverFaviconUrls)) { + onMcpServerDiscovered(sanitizedName, faviconUrl); + } + + if (mcpClients.length > 0) { + logger.info(`Connected to ${mcpClients.length} external MCP server(s): ${mcpClients.map(c => c.serverName).join(', ')}`); + } + } catch (error) { + logger.error('Failed to connect external MCP servers:', error); + } + } + + for (const serverName of mcpToolSetsObj.failedServers) { + onMcpServerFailed(serverName); + } + + const mcpRegistry = buildMcpToolRegistry(mcpToolSetsObj.tools); + const hasMcpTools = mcpRegistry.length > 0; + + const toolRequestActivation = tool({ + description: dedent` + Activate an MCP tool by name so it becomes callable on your next step. + You MUST pass an exact tool name from the tool registry in the system prompt. + Do NOT pass natural language descriptions or sentences. + If you need multiple tools, call this once per tool. + + Examples: + CORRECT: tool_to_activate_name="mcp_linear__save_comment" + CORRECT: tool_to_activate_name="mcp_linear__create_attachment" + INCORRECT: tool_to_activate_name="create a linear issue and update status" + INCORRECT: tool_to_activate_name="find tools for commenting on issues" + `, + inputSchema: z.object({ + tool_to_activate_name: z.string().describe('Exact tool name from the registry, e.g. "mcp_linear__save_comment"'), + }), + execute: async ({ tool_to_activate_name }) => { + const results = searchMcpTools(tool_to_activate_name, mcpRegistry); + return { + results: results.map(e => ({ name: e.name, description: e.description })), + }; + }, + }); + const systemPrompt = createPrompt({ repos: selectedRepos, files: resolvedFileSources, + mcpToolRegistry: mcpRegistry, }); - const stream = streamText({ - model, - providerOptions, - messages: inputMessages, - system: systemPrompt, - tools: createTools({ source: 'sourcebot-ask-agent', selectedRepos }), - temperature: temperature ?? env.SOURCEBOT_CHAT_MODEL_TEMPERATURE, - stopWhen: [ - stepCountIsGTE(env.SOURCEBOT_CHAT_MAX_STEP_COUNT), - ], - toolChoice: "auto", - onStepFinish: ({ toolResults }) => { - toolResults.forEach(({ output, dynamic }) => { - if (dynamic || isServiceError(output)) { - return; + const builtinTools = createTools({ source: 'sourcebot-ask-agent', selectedRepos }); + const builtinToolNames = Object.keys(builtinTools); + const allTools: Record = { + ...builtinTools, + ...(hasMcpTools ? { tool_request_activation: toolRequestActivation, ...mcpToolSetsObj.tools } : {}), + }; + + try { + const stream = streamText({ + model, + providerOptions, + messages: inputMessages, + system: systemPrompt, + tools: allTools, + activeTools: [ + ...builtinToolNames, + ...(hasMcpTools ? ['tool_request_activation'] : []), + ], + prepareStep: hasMcpTools ? ({ steps }) => { + const activated = new Set(); + for (const step of steps) { + for (const result of step.toolResults) { + if (!result || result.toolName !== 'tool_request_activation') { + continue; + } + const output = result.output as { results?: Array<{ name: string }> }; + for (const { name } of output?.results ?? []) { + if (name in mcpToolSetsObj.tools) { + activated.add(name); + } + } + } + } + return { + activeTools: [ + ...builtinToolNames, + 'tool_request_activation', + ...Array.from(activated), + ], + }; + } : undefined, + temperature: temperature ?? env.SOURCEBOT_CHAT_MODEL_TEMPERATURE, + stopWhen: [ + stepCountIsGTE(env.SOURCEBOT_CHAT_MAX_STEP_COUNT), + ], + toolChoice: "auto", + experimental_repairToolCall: async ({ toolCall, tools, error }) => { + // Fix case mismatches (e.g. model outputs "Mcp_Linear__Save_Comment" instead of "mcp_linear__save_comment") + if (NoSuchToolError.isInstance(error)) { + const lower = toolCall.toolName.toLowerCase(); + if (lower !== toolCall.toolName && lower in tools) { + return { ...toolCall, toolName: lower }; + } } - output.sources?.forEach(onWriteSource); - }); - }, - experimental_telemetry: { - isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true', - metadata: { - langfuseTraceId: traceId, + // For anything we can't fix, return null. + // The AI SDK will mark the call as invalid and pass the error + // back to the model so it can retry with correct parameters. + logger.warn(`Tool call repair failed for "${toolCall.toolName}": ${error.message}`); + return null; }, - }, - onError: (error) => { - logger.error(error); - }, - }); + onStepFinish: ({ toolResults }) => { + toolResults.forEach(({ output, dynamic }) => { + if (dynamic || isServiceError(output)) { + return; + } - return stream; + output.sources?.forEach(onWriteSource); + }); + }, + experimental_telemetry: { + isEnabled: env.SOURCEBOT_TELEMETRY_PII_COLLECTION_ENABLED === 'true', + metadata: { + langfuseTraceId: traceId, + }, + }, + onError: (error) => { + logger.error(error); + }, + }); + + // Clean up MCP transport connections once the stream completes (success or failure). + stream.response.then( + () => mcpToolSetsObj.cleanup(), + () => mcpToolSetsObj.cleanup() + ); + return stream; + } catch (error) { + // If anything between MCP setup and stream return throws, ensure we + // still close the MCP transport connections to avoid leaking them. + await mcpToolSetsObj.cleanup(); + throw error; + } } + const createPrompt = ({ files, repos, + mcpToolRegistry, }: { files?: { path: string; @@ -243,6 +422,7 @@ const createPrompt = ({ revision: string; }[], repos: string[], + mcpToolRegistry: McpToolRegistryEntry[], }) => { return dedent` You are a powerful agentic AI code assistant built into Sourcebot, the world's best code-intelligence platform. Your job is to help developers understand and navigate their large codebases. @@ -287,6 +467,18 @@ const createPrompt = ({ `: ''} + ${(mcpToolRegistry.length > 0) ? dedent` + + External MCP tools are available but must first be activated via \`tool_request_activation\`. + + **CRITICAL**: The list below is the complete and authoritative inventory of all tools available to you: + ${mcpToolRegistry.map(e => `- ${e.name}: ${e.description}`).join('\n')} + + **How to use tool_request_activation**: Pass the exact tool name from the list above as the \`tool_to_activate_name\` parameter. Do NOT pass natural language descriptions or sentences. If you need multiple tools, call \`tool_request_activation\` once per tool. + Example: to activate the comment tool, call \`tool_request_activation\` with tool_to_activate_name="mcp_linear__save_comment", NOT tool_to_activate_name="save a comment on an issue". + + ` : ''} + When you have sufficient context, output your answer as a structured markdown response. diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx new file mode 100644 index 000000000..882e75ce2 --- /dev/null +++ b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx @@ -0,0 +1,151 @@ +'use client'; + +import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Switch } from "@/components/ui/switch"; +import { getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { isServiceError } from "@/lib/utils"; +import { useQuery } from "@tanstack/react-query"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { AlertTriangleIcon, Plug, PlusIcon, RefreshCwIcon, ServerIcon, SettingsIcon } from "lucide-react"; +import { PlusButtonInfoCard } from "./plusButtonInfoCard"; +import { useRouter } from "next/navigation"; +import { useState } from "react"; + +interface ChatBoxPlusButtonProps { + disabledMcpServerIds: string[]; + onDisabledMcpServerIdsChange: (ids: string[]) => void; +} + +export const ChatBoxPlusButton = ({ + disabledMcpServerIds, + onDisabledMcpServerIdsChange, +}: ChatBoxPlusButtonProps) => { + const [failedFavicons, setFailedFavicons] = useState>(new Set()); + const router = useRouter(); + + const { data: servers, isError, refetch } = useQuery({ + queryKey: ['mcpServersWithStatus'], + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load MCP servers"); + } + return result; + }, + }); + + const onToggle = (serverId: string, checked: boolean) => { + if (checked) { + onDisabledMcpServerIdsChange(disabledMcpServerIds.filter((id) => id !== serverId)); + } else { + onDisabledMcpServerIdsChange([...disabledMcpServerIds, serverId]); + } + }; + + const onFaviconError = (serverId: string) => { + setFailedFavicons((prev) => new Set(prev).add(serverId)); + }; + + // Only surface servers the user has attempted to connect (connected or auth expired). + const relevantServers = servers?.filter((s) => s.isConnected || s.isAuthExpired) ?? []; + + return ( + + + + + + + + + + + + e.preventDefault()}> + + + + MCP Servers + + + {isError && relevantServers.length === 0 ? ( + { + e.preventDefault(); + refetch(); + }} + className="gap-2 text-destructive" + > + + Failed to load. Retry? + + ) : relevantServers.length === 0 ? ( + + No MCP servers connected + + ) : ( + relevantServers.map((server) => { + const isEnabled = !server.isAuthExpired && !disabledMcpServerIds.includes(server.id); + return ( + e.preventDefault()} + disabled={server.isAuthExpired} + className="flex items-center justify-between gap-2" + > +
+ {server.isAuthExpired ? ( + + ) : failedFavicons.has(server.id) ? ( + + ) : ( + // eslint-disable-next-line @next/next/no-img-element + onFaviconError(server.id)} + className="w-4 h-4 shrink-0 rounded-sm" + alt="" + /> + )} + {server.name} +
+ onToggle(server.id, checked)} + disabled={server.isAuthExpired} + className="scale-75" + /> +
+ ); + }) + )} + + router.push(`/settings/mcpServers`)} + > + + Manage MCP servers + +
+
+
+
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx b/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx index a0aae38cf..280f7f9bf 100644 --- a/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx +++ b/packages/web/src/features/chat/components/chatBox/chatBoxToolbar.tsx @@ -5,6 +5,7 @@ import { LanguageModelInfo, SearchScope } from "@/features/chat/types"; import { RepositoryQuery, SearchContextQuery } from "@/lib/types"; import { useSelectedLanguageModel } from "../../useSelectedLanguageModel"; import { AtMentionButton } from "./atMentionButton"; +import { ChatBoxPlusButton } from "./chatBoxPlusButton"; import { LanguageModelSelector } from "./languageModelSelector"; import { SearchScopeSelector } from "./searchScopeSelector"; @@ -16,6 +17,10 @@ export interface ChatBoxToolbarProps { onSelectedSearchScopesChange: (items: SearchScope[]) => void; isContextSelectorOpen: boolean; onContextSelectorOpenChanged: (isOpen: boolean) => void; + // TODO_Jack_MakeLinearTask: Make the plus button available on simplified toolbar usages (e.g. askgh) + // once additional features (beyond MCP server toggling) are added to it. + disabledMcpServerIds?: string[]; + onDisabledMcpServerIdsChange?: (ids: string[]) => void; } export const ChatBoxToolbar = ({ @@ -26,6 +31,8 @@ export const ChatBoxToolbar = ({ onSelectedSearchScopesChange, isContextSelectorOpen, onContextSelectorOpenChanged, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, }: ChatBoxToolbarProps) => { const { selectedLanguageModel, setSelectedLanguageModel } = useSelectedLanguageModel({ languageModels, @@ -33,6 +40,15 @@ export const ChatBoxToolbar = ({ return ( <> + {disabledMcpServerIds !== undefined && onDisabledMcpServerIdsChange !== undefined && ( + <> + + + + )} { + return ( +
+
+ +

Extra Features

+
+
+ Add MCP servers, include files and more. +
+
+ ); +}; \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/features/chat/components/chatThread/chatThread.tsx index f60d281b7..af0aee3cc 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThread.tsx @@ -7,10 +7,10 @@ import { CustomSlateEditor } from '@/features/chat/customSlateEditor'; import { AdditionalChatRequestParams, CustomEditor, LanguageModelInfo, SBChatMessage, SearchScope, Source } from '@/features/chat/types'; import { createUIMessage, getAllMentionElements, resetEditor, slateContentToString } from '@/features/chat/utils'; import { useChat } from '@ai-sdk/react'; -import { CreateUIMessage, DefaultChatTransport } from 'ai'; +import { CreateUIMessage, DefaultChatTransport, lastAssistantMessageIsCompleteWithApprovalResponses } from 'ai'; import { ArrowDownIcon, CopyIcon } from 'lucide-react'; import { useNavigationGuard } from 'next-navigation-guard'; -import { Fragment, useCallback, useEffect, useRef, useState } from 'react'; +import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useStickToBottom } from 'use-stick-to-bottom'; import { Descendant } from 'slate'; import { useMessagePairs } from '../../useMessagePairs'; @@ -19,12 +19,15 @@ import { ChatBox } from '../chatBox'; import { ChatBoxToolbar } from '../chatBox/chatBoxToolbar'; import { ChatThreadListItem } from './chatThreadListItem'; import { ErrorBanner } from './errorBanner'; +import { McpFailedServersBanner } from './mcpFailedServersBanner'; import { useRouter } from 'next/navigation'; import { usePrevious } from '@uidotdev/usehooks'; import { RepositoryQuery, SearchContextQuery } from '@/lib/types'; import { duplicateChat, generateAndUpdateChatNameFromMessage } from '../../actions'; import { isServiceError } from '@/lib/utils'; import { NotConfiguredErrorBanner } from '../notConfiguredErrorBanner'; +import { McpServerIconContext, McpServerIconMap } from '../../mcpServerIconContext'; +import { ToolApprovalProvider } from '../../toolApprovalContext'; import useCaptureEvent from '@/hooks/useCaptureEvent'; import { SignInPromptBanner } from './signInPromptBanner'; import { DuplicateChatDialog } from '@/app/(app)/chat/components/duplicateChatDialog'; @@ -47,6 +50,8 @@ interface ChatThreadProps { searchContexts: SearchContextQuery[]; selectedSearchScopes: SearchScope[]; onSelectedSearchScopesChange: (items: SearchScope[]) => void; + disabledMcpServerIds: string[]; + onDisabledMcpServerIdsChange: (ids: string[]) => void; isOwner?: boolean; isAuthenticated?: boolean; chatName?: string; @@ -61,6 +66,8 @@ export const ChatThread = ({ searchContexts, selectedSearchScopes, onSelectedSearchScopesChange, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, isOwner = true, isAuthenticated = false, chatName, @@ -86,13 +93,66 @@ export const ChatThread = ({ ) ?? [] ); + const [mcpServerIconMap, setMcpServerIconMap] = useState(() => { + const map: McpServerIconMap = {}; + initialMessages?.forEach((message) => { + message.parts + .filter((part) => part.type === 'data-mcp-server') + .forEach((part) => { + map[part.data.sanitizedName] = part.data.faviconUrl; + }); + }); + return map; + }); + + const [failedMcpServers, setFailedMcpServers] = useState(() => { + const names: string[] = []; + initialMessages?.forEach((message) => { + message.parts + .filter((part) => part.type === 'data-mcp-failed-server') + .forEach((part) => { + if (!names.includes(part.data.serverName)) { + names.push(part.data.serverName); + } + }); + }); + return names; + }); + const [isFailedMcpBannerVisible, setIsFailedMcpBannerVisible] = useState(false); + const { selectedLanguageModel } = useSelectedLanguageModel({ languageModels, }); + // Refs to capture the latest request params for the transport body. + // The transport is created once (useMemo) but params change over time, + // so refs ensure the dynamic body function always reads current values. + const searchScopesRef = useRef(selectedSearchScopes); + const modelRef = useRef(selectedLanguageModel); + const disabledMcpRef = useRef(disabledMcpServerIds); + + useEffect(() => { searchScopesRef.current = selectedSearchScopes; }, [selectedSearchScopes]); + useEffect(() => { modelRef.current = selectedLanguageModel; }, [selectedLanguageModel]); + useEffect(() => { disabledMcpRef.current = disabledMcpServerIds; }, [disabledMcpServerIds]); + + // Transport with dynamic body — resolved on every request (including auto-resends + // triggered by sendAutomaticallyWhen after tool approval). + const transport = useMemo(() => new DefaultChatTransport({ + api: '/api/chat', + headers: { + 'X-Sourcebot-Client-Source': 'sourcebot-web-client', + }, + body: () => ({ + selectedSearchScopes: searchScopesRef.current, + languageModel: modelRef.current, + disabledMcpServerIds: disabledMcpRef.current, + }), + }), []); + const { messages, sendMessage: _sendMessage, + addToolApprovalResponse, error, status, stop, @@ -100,17 +160,28 @@ export const ChatThread = ({ } = useChat({ id: defaultChatId, messages: initialMessages, - transport: new DefaultChatTransport({ - api: '/api/chat', - headers: { - 'X-Sourcebot-Client-Source': 'sourcebot-web-client', - }, - }), + transport, + sendAutomaticallyWhen: lastAssistantMessageIsCompleteWithApprovalResponses, onData: (dataPart) => { // Keeps sources added by the assistant in sync. if (dataPart.type === 'data-source') { setSources((prev) => [...prev, dataPart.data]); } + if (dataPart.type === 'data-mcp-server') { + setMcpServerIconMap((prev) => ({ + ...prev, + [dataPart.data.sanitizedName]: dataPart.data.faviconUrl, + })); + } + if (dataPart.type === 'data-mcp-failed-server') { + setFailedMcpServers((prev) => { + if (prev.includes(dataPart.data.serverName)) { + return prev; + } + return [...prev, dataPart.data.serverName]; + }); + setIsFailedMcpBannerVisible(true); + } } }); @@ -133,6 +204,7 @@ export const ChatThread = ({ body: { selectedSearchScopes, languageModel: selectedLanguageModel, + disabledMcpServerIds, } satisfies AdditionalChatRequestParams, }); @@ -162,6 +234,7 @@ export const ChatThread = ({ selectedLanguageModel, _sendMessage, selectedSearchScopes, + disabledMcpServerIds, messages.length, toast, chatId, @@ -231,13 +304,13 @@ export const ChatThread = ({ const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes); + const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes, disabledMcpServerIds); sendMessage(message); scrollToBottom(); } catch (error) { console.error('Failed to restore pending message:', error); } - }, [isAuthenticated, isOwner, chatId, sendMessage, selectedSearchScopes, scrollToBottom]); + }, [isAuthenticated, isOwner, chatId, sendMessage, selectedSearchScopes, disabledMcpServerIds, scrollToBottom]); // Track scroll position for history state restoration. useEffect(() => { @@ -319,13 +392,13 @@ export const ChatThread = ({ const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes); + const message = createUIMessage(text, mentions.map(({ data }) => data), selectedSearchScopes, disabledMcpServerIds); sendMessage(message); scrollToBottom(); resetEditor(editor); - }, [sendMessage, selectedSearchScopes, isAuthenticated, captureEvent, chatId, scrollToBottom]); + }, [sendMessage, selectedSearchScopes, disabledMcpServerIds, isAuthenticated, captureEvent, chatId, scrollToBottom]); const onDuplicate = useCallback(async (newName: string): Promise => { if (!defaultChatId) { @@ -347,7 +420,8 @@ export const ChatThread = ({ }, [defaultChatId, toast, router, captureEvent]); return ( - <> + + {error && ( setIsErrorBannerVisible(false)} /> )} + setIsFailedMcpBannerVisible(false)} + />
@@ -480,6 +561,7 @@ export const ChatThread = ({ providers={loginWallProviders} callbackUrl={typeof window !== 'undefined' ? window.location.href : ''} /> - + + ); } diff --git a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx b/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx index 0cbd4b264..f56bd8f8b 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx +++ b/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx @@ -6,11 +6,13 @@ import { Skeleton } from '@/components/ui/skeleton'; import { CheckCircle, Loader2 } from 'lucide-react'; import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import scrollIntoView from 'scroll-into-view-if-needed'; +import { DynamicToolUIPart } from "ai"; import { Reference, referenceSchema, SBChatMessage, Source } from "../../types"; import { useExtractReferences } from '../../useExtractReferences'; import { getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences, tryResolveFileReference } from '../../utils'; import { AnswerCard } from './answerCard'; import { DetailsCard } from './detailsCard'; +import { ToolApprovalBanner } from './toolApprovalBanner'; import { MarkdownRenderer, REFERENCE_PAYLOAD_ATTRIBUTE } from './markdownRenderer'; import { ReferencedSourcesListView } from './referencedSourcesListView'; import isEqual from "fast-deep-equal/react"; @@ -106,7 +108,8 @@ const ChatThreadListItemComponent = forwardRef { + if (!assistantMessage) { + return []; + } + return assistantMessage.parts.filter( + (part): part is DynamicToolUIPart => part.type === 'dynamic-tool' && part.state === 'approval-requested' + ); + }, [assistantMessage]); + // Auto-collapse when answer first appears, but only once and respect user preference useEffect(() => { @@ -364,6 +377,10 @@ const ChatThreadListItemComponent = forwardRef + {approvalRequestedParts.length > 0 && ( + + )} + {(answerPart && assistantMessage) ? ( - ) : !isStreaming && ( + ) : !isStreaming && approvalRequestedParts.length === 0 && (

Error: No answer response was provided

)}
diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx index 0e2365ea6..5997df6e7 100644 --- a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx +++ b/packages/web/src/features/chat/components/chatThread/detailsCard.tsx @@ -25,6 +25,8 @@ import { ListReposToolComponent } from './tools/listReposToolComponent'; import { ListTreeToolComponent } from './tools/listTreeToolComponent'; import { ReadFileToolComponent } from './tools/readFileToolComponent'; import { ToolOutputGuard } from './tools/toolOutputGuard'; +import { McpToolComponent } from './tools/mcpToolComponent'; +import { ToolSearchToolComponent } from './tools/toolSearchToolComponent'; interface DetailsCardProps { @@ -48,7 +50,10 @@ const DetailsCardComponent = ({ }: DetailsCardProps) => { const captureEvent = useCaptureEvent(); - const toolCallCount = useMemo(() => thinkingSteps.flat().filter(part => part.type.startsWith('tool-')).length, [thinkingSteps]); + const toolCallCount = useMemo(() => thinkingSteps.flat().filter(part => + part.type.startsWith('tool-') || + (part.type === 'dynamic-tool' && part.toolName.startsWith('mcp_')) + ).length, [thinkingSteps]); const handleExpandedChanged = useCallback((next: boolean) => { captureEvent('wa_chat_details_card_toggled', { chatId, isExpanded: next }); @@ -308,8 +313,19 @@ export const StepPartRenderer = ({ part }: { part: SBChatMessagePart }) => { {(output) => } ) - case 'data-source': + case 'tool-tool_request_activation': + if (part.state !== 'output-available') { + return Activating tool...; + } + return ; case 'dynamic-tool': + if (part.toolName.startsWith('mcp_')) { + return ; + } + return null; + case 'data-source': + case 'data-mcp-server': + case 'data-mcp-failed-server': case 'file': case 'source-document': case 'source-url': diff --git a/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx b/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx new file mode 100644 index 000000000..0c74fe72f --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx @@ -0,0 +1,43 @@ +'use client'; + +import { Button } from '@/components/ui/button'; +import { AlertTriangle, X } from 'lucide-react'; + +interface McpFailedServersBannerProps { + serverNames: string[]; + isVisible: boolean; + onClose: () => void; +} + +export const McpFailedServersBanner = ({ serverNames, isVisible, onClose }: McpFailedServersBannerProps) => { + if (!isVisible || serverNames.length === 0) { + return null; + } + + const message = serverNames.length === 1 + ? `MCP server "${serverNames[0]}" failed to load tools` + : `${serverNames.length} MCP servers failed to load tools`; + + return ( +
+
+
+
+ + + {message} + +
+ +
+
+
+ ); +}; \ No newline at end of file diff --git a/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx b/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx new file mode 100644 index 000000000..0724c93b7 --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx @@ -0,0 +1,101 @@ +'use client'; + +import { Button } from "@/components/ui/button"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { useToolApproval } from "@/features/chat/toolApprovalContext"; +import { cn } from "@/lib/utils"; +import { DynamicToolUIPart } from "ai"; +import { ChevronRight } from "lucide-react"; +import { useCallback, useState } from "react"; +import { parseMcpToolName } from "./tools/mcpToolComponent"; +import { JsonHighlighter } from "./tools/jsonHighlighter"; + +interface ToolApprovalBannerProps { + parts: DynamicToolUIPart[]; +} + +export const ToolApprovalBanner = ({ parts }: ToolApprovalBannerProps) => { + const addToolApprovalResponse = useToolApproval(); + const iconMap = useMcpServerIconMap(); + + if (parts.length === 0) { + return null; + } + + return ( +
+ {parts.map((part) => ( + + ))} +
+ ); +}; + +const ToolApprovalItem = ({ + part, + addToolApprovalResponse, + iconMap, +}: { + part: DynamicToolUIPart; + addToolApprovalResponse: ReturnType; + iconMap: Record; +}) => { + const [isExpanded, setIsExpanded] = useState(false); + const parsed = parseMcpToolName(part.toolName); + const serverName = parsed?.serverName ?? part.toolName; + const toolName = parsed?.toolName ?? part.toolName; + const faviconUrl = parsed ? iconMap[parsed.serverName] : undefined; + + const hasInput = part.state !== 'input-streaming'; + const requestText = hasInput ? JSON.stringify(part.input, null, 2) : ''; + + const onToggle = useCallback(() => setIsExpanded(v => !v), []); + + const onApprove = useCallback(() => { + if (part.state === 'approval-requested' && addToolApprovalResponse) { + addToolApprovalResponse({ id: part.approval.id, approved: true }); + } + }, [part, addToolApprovalResponse]); + + const onDeny = useCallback(() => { + if (part.state === 'approval-requested' && addToolApprovalResponse) { + addToolApprovalResponse({ id: part.approval.id, approved: false, reason: 'User denied' }); + } + }, [part, addToolApprovalResponse]); + + return ( +
+
+ +
+ + +
+
+ {hasInput && isExpanded && ( +
+ +
+ )} +
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx b/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx new file mode 100644 index 000000000..18203a9de --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx @@ -0,0 +1,151 @@ +'use client'; + +export function unescapeJsonStrings(value: unknown): unknown { + if (typeof value === 'string') { + try { + const parsed: unknown = JSON.parse(value); + if (typeof parsed === 'object' && parsed !== null) { + return unescapeJsonStrings(parsed); + } + } catch { + // not JSON — leave as-is + } + return value; + } + if (Array.isArray(value)) { + return value.map(unescapeJsonStrings); + } + if (typeof value === 'object' && value !== null) { + return Object.fromEntries( + Object.entries(value).map(([k, v]) => [k, unescapeJsonStrings(v)]) + ); + } + return value; +} + +type TokenType = 'key' | 'string' | 'number' | 'boolean' | 'null' | 'structural' | 'whitespace' | 'other'; + +interface Token { + type: TokenType; + value: string; +} + +function tokenizeJson(text: string): Token[] { + const tokens: Token[] = []; + let i = 0; + + while (i < text.length) { + const ch = text[i]; + + // Whitespace + if (/\s/.test(ch)) { + let j = i + 1; + while (j < text.length && /\s/.test(text[j])) { + j++; + } + tokens.push({ type: 'whitespace', value: text.slice(i, j) }); + i = j; + continue; + } + + // String + if (ch === '"') { + let j = i + 1; + while (j < text.length) { + if (text[j] === '\\') { + j += 2; + } else if (text[j] === '"') { + j++; + break; + } else { + j++; + } + } + const str = text.slice(i, j); + + // Lookahead past whitespace for a colon → this is a key + let k = j; + while (k < text.length && /\s/.test(text[k])) { + k++; + } + const isKey = text[k] === ':'; + + tokens.push({ type: isKey ? 'key' : 'string', value: str }); + i = j; + continue; + } + + // Number + if (ch === '-' || /\d/.test(ch)) { + const match = text.slice(i).match(/^-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?/); + if (match) { + tokens.push({ type: 'number', value: match[0] }); + i += match[0].length; + continue; + } + } + + // Boolean / null keywords + if (text.slice(i, i + 4) === 'true') { + tokens.push({ type: 'boolean', value: 'true' }); + i += 4; + continue; + } + if (text.slice(i, i + 5) === 'false') { + tokens.push({ type: 'boolean', value: 'false' }); + i += 5; + continue; + } + if (text.slice(i, i + 4) === 'null') { + tokens.push({ type: 'null', value: 'null' }); + i += 4; + continue; + } + + // Structural characters + if ('{}[]:,'.includes(ch)) { + tokens.push({ type: 'structural', value: ch }); + i++; + continue; + } + + // Fallback + tokens.push({ type: 'other', value: ch }); + i++; + } + + return tokens; +} + +const TOKEN_CLASSES: Record = { + key: 'text-editor-tag-name', + string: 'text-editor-tag-string', + number: 'text-editor-tag-number', + boolean: 'text-editor-tag-atom', + null: 'text-editor-tag-atom', + structural: 'text-muted-foreground', + whitespace: '', + other: '', +}; + +import { useMemo } from "react"; + +export const JsonHighlighter = ({ text }: { text: string }) => { + const tokens = useMemo(() => tokenizeJson(text), [text]); + + return ( +
+            {tokens.map((token, i) => {
+                const cls = TOKEN_CLASSES[token.type];
+                if (!cls) {
+                    return token.value;
+                }
+                return (
+                    
+                        {token.value}
+                    
+                );
+            })}
+        
+ ); +}; diff --git a/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx new file mode 100644 index 000000000..3e679a21b --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx @@ -0,0 +1,173 @@ +'use client'; + +import { CopyIconButton } from "@/app/(app)/components/copyIconButton"; +import { McpFavicon } from "@/ee/features/mcp/components/mcpFavicon"; +import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { cn } from "@/lib/utils"; +import { DynamicToolUIPart } from "ai"; +import { CheckCircle, ChevronDown, XCircle } from "lucide-react"; +import { useCallback, useMemo, useState } from "react"; +import { JsonHighlighter, unescapeJsonStrings } from "./jsonHighlighter"; + +export function parseMcpToolName(toolName: string): { serverName: string; toolName: string } | null { + if (!toolName.startsWith('mcp_')) { + return null; + } + const withoutPrefix = toolName.slice(4); + const doubleUnderscoreIdx = withoutPrefix.indexOf('__'); + if (doubleUnderscoreIdx === -1) { + return null; + } + return { + serverName: withoutPrefix.slice(0, doubleUnderscoreIdx), + toolName: withoutPrefix.slice(doubleUnderscoreIdx + 2), + }; +} + +export const McpToolComponent = ({ part }: { part: DynamicToolUIPart }) => { + const needsApproval = part.state === 'approval-requested'; + const [isExpanded, setIsExpanded] = useState(needsApproval); + const onToggle = useCallback(() => setIsExpanded(v => !v), []); + + const iconMap = useMcpServerIconMap(); + const parsed = parseMcpToolName(part.toolName); + const displayName = parsed + ? `${parsed.serverName}: ${parsed.toolName}` + : part.toolName; + const faviconUrl = parsed ? iconMap[parsed.serverName] : undefined; + + const hasInput = part.state !== 'input-streaming'; + + const requestText = useMemo( + () => hasInput ? JSON.stringify(part.input, null, 2) : '', + [hasInput, part.input] + ); + const responseText = useMemo(() => { + if (part.state === 'output-available') { + try { + return JSON.stringify(unescapeJsonStrings(part.output), null, 2); + } catch { + return String(part.output); + } + } + if (part.state === 'output-error') { + return part.errorText ?? ''; + } + return undefined; + }, [part.state, part.output, part.errorText]); + + const onCopyRequest = useCallback(() => { + navigator.clipboard.writeText(requestText); + return true; + }, [requestText]); + + const onCopyResponse = useCallback(() => { + if (!responseText) { + return false; + } + navigator.clipboard.writeText(responseText); + return true; + }, [responseText]); + + const renderStatus = () => { + if (part.state === 'output-error') { + return ( + + + {displayName} failed: {part.errorText} + + ); + } + if (part.state === 'output-denied') { + return ( + + + + {displayName} — denied + + ); + } + if (part.state === 'approval-requested') { + return ( + + + {displayName} + + ); + } + if (part.state === 'approval-responded') { + const approved = part.approval.approved; + return ( + + + {approved ? : } + {displayName}{approved ? '...' : ' — denied'} + + ); + } + if (part.state === 'output-available') { + return ( + + + {displayName} + + ); + } + // input-streaming, input-available, or other in-progress states + return ( + + + {displayName}... + + ); + }; + + return ( +
+
+
+ {renderStatus()} +
+ {hasInput && ( + + )} +
+ {hasInput && isExpanded && ( +
+ + + + {responseText !== undefined && ( + <> +
+ +
+ +
+
+ + )} +
+ )} +
+ ); +}; + + +const ResultSection = ({ label, onCopy, children }: { label: string; onCopy: () => boolean; children: React.ReactNode }) => ( +
+
+ {label} + +
+
+ {children} +
+
+); diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx b/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx index aac756f4a..43ce2021d 100644 --- a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx +++ b/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx @@ -6,6 +6,7 @@ import { ToolUIPart } from "ai"; import { ChevronDown } from "lucide-react"; import { cn } from "@/lib/utils"; import { useCallback, useState } from "react"; +import { JsonHighlighter, unescapeJsonStrings } from "./jsonHighlighter"; export const ToolOutputGuard = >({ part, @@ -27,7 +28,7 @@ export const ToolOutputGuard = { const raw = (part.output as { output: string }).output; try { - return JSON.stringify(JSON.parse(raw), null, 2); + return JSON.stringify(unescapeJsonStrings(JSON.parse(raw)), null, 2); } catch { return raw; } @@ -70,17 +71,15 @@ export const ToolOutputGuard = -
-                            {requestText}
-                        
+
{responseText !== undefined && ( <>
-
-                                    {responseText}
-                                
+
+ +
)} diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx b/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx new file mode 100644 index 000000000..3711e22bd --- /dev/null +++ b/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx @@ -0,0 +1,53 @@ +'use client'; + +import { Separator } from "@/components/ui/separator"; +import { Collapsible, CollapsibleContent, CollapsibleTrigger } from "@/components/ui/collapsible"; +import { ChevronRight } from "lucide-react"; +import { useState } from "react"; +import { cn } from "@/lib/utils"; + +interface ToolSearchResult { + name: string; + description: string; +} + +interface ToolSearchToolComponentProps { + query: string; + results: ToolSearchResult[]; +} + +export const ToolSearchToolComponent = ({ query, results }: ToolSearchToolComponentProps) => { + const [isOpen, setIsOpen] = useState(false); + + return ( + + +
+ + Searched MCP tools: {query} + + {results.length} result{results.length === 1 ? '' : 's'} + +
+
+ +
+ {results.map((result) => ( +
+ {result.name} + {result.description && ( + <> + - + {result.description} + + )} +
+ ))} + {results.length === 0 && ( + No tools found + )} +
+
+
+ ); +}; diff --git a/packages/web/src/features/chat/constants.ts b/packages/web/src/features/chat/constants.ts index b84e9d922..5ae681b32 100644 --- a/packages/web/src/features/chat/constants.ts +++ b/packages/web/src/features/chat/constants.ts @@ -9,3 +9,4 @@ export const ANSWER_TAG = ''; export const SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY = 'selectedSearchScopes'; export const SET_CHAT_STATE_SESSION_STORAGE_KEY = 'setChatState'; +export const DISABLED_MCP_SERVER_IDS_LOCAL_STORAGE_KEY = 'disabledMcpServerIds'; diff --git a/packages/web/src/features/chat/mcpServerIconContext.tsx b/packages/web/src/features/chat/mcpServerIconContext.tsx new file mode 100644 index 000000000..94628f4a5 --- /dev/null +++ b/packages/web/src/features/chat/mcpServerIconContext.tsx @@ -0,0 +1,10 @@ +'use client'; + +import { createContext, useContext } from 'react'; + +// Maps sanitized server name (e.g. "linear") to a favicon URL. +export type McpServerIconMap = Record; + +export const McpServerIconContext = createContext({}); + +export const useMcpServerIconMap = () => useContext(McpServerIconContext); diff --git a/packages/web/src/features/chat/toolApprovalContext.tsx b/packages/web/src/features/chat/toolApprovalContext.tsx new file mode 100644 index 000000000..d4379c394 --- /dev/null +++ b/packages/web/src/features/chat/toolApprovalContext.tsx @@ -0,0 +1,9 @@ +'use client'; + +import { createContext, useContext } from 'react'; +import type { ChatAddToolApproveResponseFunction } from 'ai'; + +const ToolApprovalContext = createContext(null); + +export const ToolApprovalProvider = ToolApprovalContext.Provider; +export const useToolApproval = () => useContext(ToolApprovalContext); \ No newline at end of file diff --git a/packages/web/src/features/chat/types.test.ts b/packages/web/src/features/chat/types.test.ts new file mode 100644 index 000000000..a9f41df7c --- /dev/null +++ b/packages/web/src/features/chat/types.test.ts @@ -0,0 +1,72 @@ +import { expect, test, describe } from 'vitest'; +import { sbChatMessageMetadataSchema, additionalChatRequestParamsSchema } from './types'; + +describe('sbChatMessageMetadataSchema', () => { + test('accepts disabledMcpServerIds as array of strings', () => { + const result = sbChatMessageMetadataSchema.safeParse({ + disabledMcpServerIds: ['id1', 'id2'], + }); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual(['id1', 'id2']); + } + }); + + test('accepts missing disabledMcpServerIds (optional)', () => { + const result = sbChatMessageMetadataSchema.safeParse({}); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toBeUndefined(); + } + }); + + test('rejects non-string array values', () => { + const result = sbChatMessageMetadataSchema.safeParse({ + disabledMcpServerIds: [123, 456], + }); + + expect(result.success).toBe(false); + }); +}); + +describe('additionalChatRequestParamsSchema', () => { + const validBase = { + languageModel: { + provider: 'anthropic', + model: 'claude-sonnet-4-20250514', + }, + selectedSearchScopes: [], + }; + + test('defaults disabledMcpServerIds to empty array', () => { + const result = additionalChatRequestParamsSchema.safeParse(validBase); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual([]); + } + }); + + test('accepts explicit disabledMcpServerIds array', () => { + const result = additionalChatRequestParamsSchema.safeParse({ + ...validBase, + disabledMcpServerIds: ['abc'], + }); + + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.disabledMcpServerIds).toEqual(['abc']); + } + }); + + test('rejects non-array value for disabledMcpServerIds', () => { + const result = additionalChatRequestParamsSchema.safeParse({ + ...validBase, + disabledMcpServerIds: 'not-an-array', + }); + + expect(result.success).toBe(false); + }); +}); diff --git a/packages/web/src/features/chat/types.ts b/packages/web/src/features/chat/types.ts index 6e990f5c2..3c2619f14 100644 --- a/packages/web/src/features/chat/types.ts +++ b/packages/web/src/features/chat/types.ts @@ -60,6 +60,7 @@ export const sbChatMessageMetadataSchema = z.object({ userId: z.string().optional(), })).optional(), selectedSearchScopes: z.array(searchScopeSchema).optional(), + disabledMcpServerIds: z.array(z.string()).optional(), traceId: z.string().optional(), }); @@ -67,12 +68,22 @@ export type SBChatMessageMetadata = z.infer; export type SBChatMessageToolTypes = { [K in keyof ReturnType]: InferUITool[K]>; +} & { + tool_request_activation: { + input: { tool_to_activate_name: string }; + output: { results: Array<{ name: string; description: string }> }; + }; }; export type SBChatMessageDataParts = { // The `source` data type allows us to know what sources the LLM saw // during retrieval. "source": Source, + // The `mcp-server` data type carries favicon metadata for connected MCP servers, + // keyed by sanitized server name (e.g. "linear"). + "mcp-server": { sanitizedName: string; faviconUrl: string }, + // The `mcp-failed-server` data type surfaces MCP servers that failed to load their tools. + "mcp-failed-server": { serverName: string }, } export type SBChatMessage = UIMessage< @@ -143,6 +154,7 @@ declare module 'slate' { export type SetChatStatePayload = { inputMessage: CreateUIMessage; selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; } @@ -188,5 +200,6 @@ export type LanguageModelInfo = { export const additionalChatRequestParamsSchema = z.object({ languageModel: languageModelInfoSchema, selectedSearchScopes: z.array(searchScopeSchema), + disabledMcpServerIds: z.array(z.string()).default([]), }); -export type AdditionalChatRequestParams = z.infer; \ No newline at end of file +export type AdditionalChatRequestParams = z.infer; diff --git a/packages/web/src/features/chat/useCreateNewChatThread.ts b/packages/web/src/features/chat/useCreateNewChatThread.ts index 63ead0249..7cf72a0ce 100644 --- a/packages/web/src/features/chat/useCreateNewChatThread.ts +++ b/packages/web/src/features/chat/useCreateNewChatThread.ts @@ -30,11 +30,11 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew const hasRestoredPendingMessage = useRef(false); const captureEvent = useCaptureEvent(); - const doCreateChat = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[]) => { + const doCreateChat = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[]) => { const text = slateContentToString(children); const mentions = getAllMentionElements(children); - const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes); + const inputMessage = createUIMessage(text, mentions.map((mention) => mention.data), selectedSearchScopes, disabledMcpServerIds); setIsLoading(true); const response = await createChat({ source: 'sourcebot-web-client' }); @@ -49,6 +49,7 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew setChatState({ inputMessage, selectedSearchScopes, + disabledMcpServerIds, }); const url = createPathWithQueryParams(`/chat/${response.id}`); @@ -56,18 +57,18 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew router.push(url); }, [router, toast, setChatState]); - const createNewChatThread = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[]) => { + const createNewChatThread = useCallback(async (children: Descendant[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[]) => { if (!isAuthenticated) { const result = await getAskGhLoginWallData(); if (!isServiceError(result) && result.isEnabled) { captureEvent('wa_askgh_login_wall_prompted', {}); - sessionStorage.setItem(PENDING_NEW_CHAT_KEY, JSON.stringify({ children, selectedSearchScopes })); + sessionStorage.setItem(PENDING_NEW_CHAT_KEY, JSON.stringify({ children, selectedSearchScopes, disabledMcpServerIds })); setLoginWallState({ isOpen: true, providers: result.providers }); return; } } - doCreateChat(children, selectedSearchScopes); + doCreateChat(children, selectedSearchScopes, disabledMcpServerIds); }, [isAuthenticated, captureEvent, doCreateChat]); // Restore pending message after OAuth redirect @@ -85,11 +86,12 @@ export const useCreateNewChatThread = ({ isAuthenticated = false }: UseCreateNew sessionStorage.removeItem(PENDING_NEW_CHAT_KEY); try { - const { children, selectedSearchScopes } = JSON.parse(stored) as { + const { children, selectedSearchScopes, disabledMcpServerIds } = JSON.parse(stored) as { children: Descendant[]; selectedSearchScopes: SearchScope[]; + disabledMcpServerIds: string[]; }; - doCreateChat(children, selectedSearchScopes); + doCreateChat(children, selectedSearchScopes, disabledMcpServerIds ?? []); } catch (error) { console.error('Failed to restore pending message:', error); } diff --git a/packages/web/src/features/chat/utils.test.ts b/packages/web/src/features/chat/utils.test.ts index 26359d2a9..e5a89c0bb 100644 --- a/packages/web/src/features/chat/utils.test.ts +++ b/packages/web/src/features/chat/utils.test.ts @@ -1,5 +1,5 @@ -import { expect, test, vi } from 'vitest' -import { fileReferenceToString, getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences } from './utils' +import { expect, test, describe, vi } from 'vitest' +import { createUIMessage, fileReferenceToString, getAnswerPartFromAssistantMessage, groupMessageIntoSteps, repairReferences } from './utils' import { FILE_REFERENCE_REGEX, ANSWER_TAG } from './constants'; import { SBChatMessage, SBChatMessagePart } from './types'; @@ -351,3 +351,31 @@ test('repairReferences handles malformed inline code blocks', () => { const expected = 'See @file:{github.com/sourcebot-dev/sourcebot::packages/web/src/auth.ts} for details.'; expect(repairReferences(input)).toBe(expected); }); + +describe('createUIMessage', () => { + test('includes disabledMcpServerIds in metadata when provided', () => { + const result = createUIMessage('hello', [], [], ['server1', 'server2']); + + expect(result.metadata?.disabledMcpServerIds).toEqual(['server1', 'server2']); + }); + + test('defaults disabledMcpServerIds to empty array when omitted', () => { + const result = createUIMessage('hello', [], []); + + expect(result.metadata?.disabledMcpServerIds).toEqual([]); + }); + + test('passes through empty array', () => { + const result = createUIMessage('hello', [], [], []); + + expect(result.metadata?.disabledMcpServerIds).toEqual([]); + }); + + test('includes both selectedSearchScopes and disabledMcpServerIds in metadata', () => { + const scopes = [{ type: 'repo' as const, value: 'org/repo', name: 'repo', codeHostType: 'github' }]; + const result = createUIMessage('hello', [], scopes, ['disabled1']); + + expect(result.metadata?.selectedSearchScopes).toEqual(scopes); + expect(result.metadata?.disabledMcpServerIds).toEqual(['disabled1']); + }); +}); diff --git a/packages/web/src/features/chat/utils.ts b/packages/web/src/features/chat/utils.ts index 38dd784fd..2ecccd727 100644 --- a/packages/web/src/features/chat/utils.ts +++ b/packages/web/src/features/chat/utils.ts @@ -176,7 +176,7 @@ export const addLineNumbers = (source: string, lineOffset = 1) => { return source.split('\n').map((line, index) => `${index + lineOffset}: ${line}`).join('\n'); } -export const createUIMessage = (text: string, mentions: MentionData[], selectedSearchScopes: SearchScope[]): CreateUIMessage => { +export const createUIMessage = (text: string, mentions: MentionData[], selectedSearchScopes: SearchScope[], disabledMcpServerIds: string[] = []): CreateUIMessage => { // Converts applicable mentions into sources. const sources: Source[] = mentions .map((mention) => { @@ -209,6 +209,7 @@ export const createUIMessage = (text: string, mentions: MentionData[], selectedS ], metadata: { selectedSearchScopes, + disabledMcpServerIds, }, } } diff --git a/packages/web/src/features/mcp/prismaOAuthClientProvider.ts b/packages/web/src/features/mcp/prismaOAuthClientProvider.ts new file mode 100644 index 000000000..0e0b89819 --- /dev/null +++ b/packages/web/src/features/mcp/prismaOAuthClientProvider.ts @@ -0,0 +1,168 @@ +import 'server-only'; +import type { + OAuthClientProvider, + OAuthClientInformation, + OAuthClientMetadata, + OAuthTokens, +} from '@ai-sdk/mcp'; +// Note: We use the raw (unscoped) prisma client here intentionally. The user-scoped +// prisma extension only filters Repo queries, and all MCP queries in this file already +// filter explicitly by userId and/or serverId, so scoping would be a no-op. +import { __unsafePrisma } from '@/prisma'; +import { encryptOAuthToken, decryptOAuthToken } from '@sourcebot/shared'; + +/** + * Prisma-backed OAuthClientProvider for connecting to external MCP servers. + * + * Stores dynamic client registration (client_id/secret) on McpServer (per-org), + * and per-user tokens + ephemeral PKCE state on McpServerCredential. + */ +export class PrismaOAuthClientProvider implements OAuthClientProvider { + constructor( + private readonly serverId: string, + private readonly userId: string, + private readonly callbackUrl: string, + ) {} + + /** Populated by redirectToAuthorization — read after auth() returns 'REDIRECT'. */ + public authorizationUrl: string | undefined; + + get redirectUrl(): string | URL { + return this.callbackUrl; + } + + get clientMetadata(): OAuthClientMetadata { + return { + redirect_uris: [this.callbackUrl], + client_name: 'Sourcebot', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + } + + async clientInformation(): Promise { + const server = await __unsafePrisma.mcpServer.findUnique({ + where: { id: this.serverId }, + select: { clientInfo: true }, + }); + if (!server?.clientInfo) return undefined; + + const decrypted = decryptOAuthToken(server.clientInfo); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async saveClientInformation(info: OAuthClientInformation): Promise { + const encrypted = encryptOAuthToken(JSON.stringify(info)); + await __unsafePrisma.mcpServer.update({ + where: { id: this.serverId }, + data: { clientInfo: encrypted }, + }); + } + + async tokens(): Promise { + const cred = await this.getOrCreateCredential(); + if (!cred.tokens) return undefined; + + const decrypted = decryptOAuthToken(cred.tokens); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async saveTokens(tokens: OAuthTokens): Promise { + const encrypted = encryptOAuthToken(JSON.stringify(tokens)); + const tokensExpiresAt = tokens.expires_in + ? new Date(Date.now() + tokens.expires_in * 1000) + : null; + await __unsafePrisma.mcpServerCredential.update({ + where: { userId_serverId: { userId: this.userId, serverId: this.serverId } }, + data: { tokens: encrypted, tokensExpiresAt }, + }); + } + + async codeVerifier(): Promise { + const cred = await this.getOrCreateCredential(); + if (!cred.codeVerifier) { + throw new Error('No code verifier found'); + } + return cred.codeVerifier; + } + + async saveCodeVerifier(codeVerifier: string): Promise { + await this.upsertCredential({ codeVerifier }); + } + + async state(): Promise { + return crypto.randomUUID(); + } + + async saveState(state: string): Promise { + await this.upsertCredential({ state }); + } + + async storedState(): Promise { + const cred = await this.getOrCreateCredential(); + return cred.state ?? undefined; + } + + async redirectToAuthorization(url: URL): Promise { + // Force the OAuth provider to show a consent/login screen on every authorization. + // This prevents a stolen-session attack where an attacker signs into Sourcebot on + // a victim's machine and silently obtains the victim's provider tokens via an + // existing browser session. + if (!url.searchParams.has('prompt')) { + url.searchParams.set('prompt', 'consent'); + } + + // Clear any stale tokens from the database. This is called when the SDK determines + // that existing tokens are no longer valid (e.g., the access token expired and the + // refresh token was revoked). Clearing them ensures the UI reflects "not connected" + // so the user knows to re-authenticate, rather than staying stuck in a state where + // the server appears connected but all tool calls fail. + await this.invalidateCredentials('tokens'); + + this.authorizationUrl = url.toString(); + } + + async invalidateCredentials( + scope: 'all' | 'client' | 'tokens' | 'verifier', + ): Promise { + if (scope === 'all' || scope === 'client') { + await __unsafePrisma.mcpServer.update({ + where: { id: this.serverId }, + data: { clientInfo: null }, + }); + } + + if (scope === 'all' || scope === 'tokens') { + await this.upsertCredential({ tokens: null }); + } + + if (scope === 'all' || scope === 'verifier') { + await this.upsertCredential({ codeVerifier: null, state: null }); + } + } + + private async getOrCreateCredential() { + return __unsafePrisma.mcpServerCredential.upsert({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + create: { userId: this.userId, serverId: this.serverId }, + update: {}, + }); + } + + private async upsertCredential(data: { + tokens?: string | null; + codeVerifier?: string | null; + state?: string | null; + }) { + await __unsafePrisma.mcpServerCredential.upsert({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + create: { userId: this.userId, serverId: this.serverId, ...data }, + update: data, + }); + } +} \ No newline at end of file diff --git a/packages/web/src/lib/errorCodes.ts b/packages/web/src/lib/errorCodes.ts index 714932c30..fdb09d67d 100644 --- a/packages/web/src/lib/errorCodes.ts +++ b/packages/web/src/lib/errorCodes.ts @@ -35,4 +35,6 @@ export enum ErrorCode { LAST_OWNER_CANNOT_BE_DEMOTED = 'LAST_OWNER_CANNOT_BE_DEMOTED', LAST_OWNER_CANNOT_BE_REMOVED = 'LAST_OWNER_CANNOT_BE_REMOVED', API_KEY_USAGE_DISABLED = 'API_KEY_USAGE_DISABLED', + MCP_SERVER_ALREADY_EXISTS = 'MCP_SERVER_ALREADY_EXISTS', + MCP_SERVER_NOT_FOUND = 'MCP_SERVER_NOT_FOUND', } diff --git a/yarn.lock b/yarn.lock index 7be7eb0ae..4aaba200f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -99,6 +99,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/mcp@npm:^2.0.0-beta.11": + version: 2.0.0-beta.11 + resolution: "@ai-sdk/mcp@npm:2.0.0-beta.11" + dependencies: + "@ai-sdk/provider": "npm:4.0.0-beta.5" + "@ai-sdk/provider-utils": "npm:5.0.0-beta.7" + pkce-challenge: "npm:^5.0.0" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/efcc9b9f5f8b20b78b2d0ee6d83b34466b2ec456c3b40b5b8b10af226e7d3f6144f964d87a20c5fc54c24b21f3610cb75cc246c30833b99fb501438a206c9933 + languageName: node + linkType: hard + "@ai-sdk/mistral@npm:^3.0.30": version: 3.0.30 resolution: "@ai-sdk/mistral@npm:3.0.30" @@ -148,6 +161,19 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider-utils@npm:5.0.0-beta.7": + version: 5.0.0-beta.7 + resolution: "@ai-sdk/provider-utils@npm:5.0.0-beta.7" + dependencies: + "@ai-sdk/provider": "npm:4.0.0-beta.5" + "@standard-schema/spec": "npm:^1.1.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/440825f7b599da6a0bd830c905f9ba4f21defcf7068bc98154ea38158c1134b049cb2815047013668f48b679a23de1d3c19eb072a65115dc860070168104c99e + languageName: node + linkType: hard + "@ai-sdk/provider@npm:3.0.8": version: 3.0.8 resolution: "@ai-sdk/provider@npm:3.0.8" @@ -157,6 +183,15 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/provider@npm:4.0.0-beta.5": + version: 4.0.0-beta.5 + resolution: "@ai-sdk/provider@npm:4.0.0-beta.5" + dependencies: + json-schema: "npm:^0.4.0" + checksum: 10c0/886f5892268cc3425130c9b019a9eb1e2acdb5efd05d920b05d1ac1ab49603393d8e509e6e0a3c46dee533a411a51a2af2c6fa0a173b41130f5175a615add7fb + languageName: node + linkType: hard + "@ai-sdk/react@npm:^3.0.169": version: 3.0.169 resolution: "@ai-sdk/react@npm:3.0.169" @@ -9062,6 +9097,7 @@ __metadata: "@ai-sdk/deepseek": "npm:^2.0.29" "@ai-sdk/google": "npm:^3.0.64" "@ai-sdk/google-vertex": "npm:^4.0.111" + "@ai-sdk/mcp": "npm:^2.0.0-beta.11" "@ai-sdk/mistral": "npm:^3.0.30" "@ai-sdk/openai": "npm:^3.0.53" "@ai-sdk/openai-compatible": "npm:^2.0.41" @@ -9274,7 +9310,7 @@ __metadata: vitest: "npm:^4.1.4" vitest-mock-extended: "npm:^4.0.0" vscode-icons-js: "npm:^11.6.1" - zod: "npm:^3.25.74" + zod: "npm:^3.25.76" zod-to-json-schema: "npm:^3.24.5" languageName: unknown linkType: soft @@ -18526,13 +18562,20 @@ __metadata: languageName: node linkType: hard -"picomatch@npm:^4.0.2, picomatch@npm:^4.0.3, picomatch@npm:^4.0.4": +"picomatch@npm:^4.0.2, picomatch@npm:^4.0.4": version: 4.0.4 resolution: "picomatch@npm:4.0.4" checksum: 10c0/e2c6023372cc7b5764719a5ffb9da0f8e781212fa7ca4bd0562db929df8e117460f00dff3cb7509dacfc06b86de924b247f504d0ce1806a37fac4633081466b0 languageName: node linkType: hard +"picomatch@npm:^4.0.3": + version: 4.0.3 + resolution: "picomatch@npm:4.0.3" + checksum: 10c0/9582c951e95eebee5434f59e426cddd228a7b97a0161a375aed4be244bd3fe8e3a31b846808ea14ef2c8a2527a6eeab7b3946a67d5979e81694654f939473ae2 + languageName: node + linkType: hard + "picospinner@npm:^3.0.0": version: 3.0.0 resolution: "picospinner@npm:3.0.0" @@ -23045,7 +23088,7 @@ __metadata: languageName: node linkType: hard -"zod@npm:^3.25.0": +"zod@npm:^3.25.0, zod@npm:^3.25.76": version: 3.25.76 resolution: "zod@npm:3.25.76" checksum: 10c0/5718ec35e3c40b600316c5b4c5e4976f7fee68151bc8f8d90ec18a469be9571f072e1bbaace10f1e85cf8892ea12d90821b200e980ab46916a6166a4260a983c