From 33514c579c3011dcac11da78fff13c333525d480 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Mon, 11 May 2026 16:53:09 -0400 Subject: [PATCH] feat: add MCP remote proxy support --- .env.example | 2 + librechat.example.yaml | 1 + .../mcp/__tests__/MCPConnectionSSRF.test.ts | 264 +++++++++++++++++ packages/api/src/mcp/__tests__/mcp.spec.ts | 26 ++ packages/api/src/mcp/connection.ts | 277 +++++++++++++++--- packages/api/src/utils/env.spec.ts | 18 ++ packages/api/src/utils/env.ts | 11 + packages/data-provider/specs/mcp.spec.ts | 38 ++- packages/data-provider/src/mcp.ts | 25 ++ 9 files changed, 627 insertions(+), 35 deletions(-) diff --git a/.env.example b/.env.example index 00030c7187..9f7335343b 100644 --- a/.env.example +++ b/.env.example @@ -116,6 +116,8 @@ NODE_MAX_OLD_SPACE_SIZE=6144 # ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic +# Optional outbound proxy for server-side requests, including remote MCP HTTP/SSE transports. +# Remote MCP transports also honor HTTP_PROXY, HTTPS_PROXY, and NO_PROXY when PROXY is unset. PROXY= #===================================# diff --git a/librechat.example.yaml b/librechat.example.yaml index 7d61b486f4..7efe2842b3 100644 --- a/librechat.example.yaml +++ b/librechat.example.yaml @@ -304,6 +304,7 @@ actions: # everything: # # type: sse # type can optionally be omitted # url: http://localhost:3001/sse +# # proxy: "${MCP_PROXY_URL}" # optional outbound proxy (http/https/socks/socks5) # timeout: 60000 # 1 minute timeout for this server, this is the default timeout for MCP servers. # puppeteer: # type: stdio diff --git a/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts b/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts index 98f698c9cd..315e984468 100644 --- a/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts +++ b/packages/api/src/mcp/__tests__/MCPConnectionSSRF.test.ts @@ -51,6 +51,7 @@ jest.mock('~/auth', () => ({ callback(null, '127.0.0.1', 4); }, })), + isSSRFTarget: jest.fn(() => false), resolveHostnameSSRF: jest.fn(async () => false), })); @@ -852,6 +853,269 @@ describe('MCP SSRF protection – customFetch input shapes', () => { return factory.call(connection, () => null); } + it('should allocate proxy dispatchers for streamable-http when proxy is configured', () => { + conn = new MCPConnection({ + serverName: 'customfetch-proxy-dispatchers', + serverConfig: { + type: 'streamable-http', + url: 'https://mcp.example.com/mcp', + proxy: 'http://proxy.example.com:8080', + }, + useSSRFProtection: false, + }); + + const privateSelf = conn as unknown as { + agents: Array<{ constructor: { name: string } }>; + createFetchFunction: ( + getHeaders: () => Record | null | undefined, + timeout?: number, + sseBodyTimeout?: number, + configuredSecretHeaderKeys?: ReadonlySet, + baseUrl?: string, + ) => CustomFetch; + }; + privateSelf.createFetchFunction.call( + conn, + () => null, + undefined, + 300000, + undefined, + 'https://mcp.example.com/mcp', + ); + + expect(privateSelf.agents.map((agent) => agent.constructor.name)).toEqual([ + 'ProxyAgent', + 'ProxyAgent', + ]); + }); + + it('should use the PROXY env var for streamable-http when server proxy is not configured', () => { + const originalProxy = process.env.PROXY; + process.env.PROXY = 'http://env-proxy.example.com:8080'; + try { + conn = new MCPConnection({ + serverName: 'customfetch-env-proxy-dispatchers', + serverConfig: { + type: 'streamable-http', + url: 'https://mcp.example.com/mcp', + }, + useSSRFProtection: false, + }); + + const privateSelf = conn as unknown as { + agents: Array<{ constructor: { name: string } }>; + createFetchFunction: ( + getHeaders: () => Record | null | undefined, + timeout?: number, + sseBodyTimeout?: number, + configuredSecretHeaderKeys?: ReadonlySet, + baseUrl?: string, + ) => CustomFetch; + }; + privateSelf.createFetchFunction.call( + conn, + () => null, + undefined, + 300000, + undefined, + 'https://mcp.example.com/mcp', + ); + + expect(privateSelf.agents.map((agent) => agent.constructor.name)).toEqual([ + 'ProxyAgent', + 'ProxyAgent', + ]); + } finally { + if (originalProxy == null) { + delete process.env.PROXY; + } else { + process.env.PROXY = originalProxy; + } + } + }); + + it('should use standard HTTP proxy env vars for streamable-http when PROXY is absent', () => { + const originalProxy = process.env.PROXY; + const originalHttpProxy = process.env.HTTP_PROXY; + const originalHttpsProxy = process.env.HTTPS_PROXY; + const originalNoProxy = process.env.NO_PROXY; + const originalLowerHttpProxy = process.env.http_proxy; + const originalLowerHttpsProxy = process.env.https_proxy; + const originalLowerNoProxy = process.env.no_proxy; + + delete process.env.PROXY; + delete process.env.http_proxy; + delete process.env.https_proxy; + delete process.env.no_proxy; + process.env.HTTP_PROXY = 'http://http-proxy.example.com:8080'; + process.env.HTTPS_PROXY = 'http://https-proxy.example.com:8080'; + process.env.NO_PROXY = 'localhost,127.0.0.1'; + + try { + conn = new MCPConnection({ + serverName: 'customfetch-standard-env-proxy-dispatchers', + serverConfig: { + type: 'streamable-http', + url: 'https://mcp.example.com/mcp', + }, + useSSRFProtection: false, + }); + + const privateSelf = conn as unknown as { + agents: Array<{ constructor: { name: string } }>; + createFetchFunction: ( + getHeaders: () => Record | null | undefined, + timeout?: number, + sseBodyTimeout?: number, + configuredSecretHeaderKeys?: ReadonlySet, + baseUrl?: string, + ) => CustomFetch; + }; + privateSelf.createFetchFunction.call( + conn, + () => null, + undefined, + 300000, + undefined, + 'https://mcp.example.com/mcp', + ); + + expect(privateSelf.agents.map((agent) => agent.constructor.name)).toEqual([ + 'ProxyAgent', + 'ProxyAgent', + ]); + } finally { + if (originalProxy == null) { + delete process.env.PROXY; + } else { + process.env.PROXY = originalProxy; + } + if (originalHttpProxy == null) { + delete process.env.HTTP_PROXY; + } else { + process.env.HTTP_PROXY = originalHttpProxy; + } + if (originalHttpsProxy == null) { + delete process.env.HTTPS_PROXY; + } else { + process.env.HTTPS_PROXY = originalHttpsProxy; + } + if (originalNoProxy == null) { + delete process.env.NO_PROXY; + } else { + process.env.NO_PROXY = originalNoProxy; + } + if (originalLowerHttpProxy == null) { + delete process.env.http_proxy; + } else { + process.env.http_proxy = originalLowerHttpProxy; + } + if (originalLowerHttpsProxy == null) { + delete process.env.https_proxy; + } else { + process.env.https_proxy = originalLowerHttpsProxy; + } + if (originalLowerNoProxy == null) { + delete process.env.no_proxy; + } else { + process.env.no_proxy = originalLowerNoProxy; + } + } + }); + + it('should honor NO_PROXY when standard HTTP proxy env vars are configured', () => { + const originalProxy = process.env.PROXY; + const originalHttpsProxy = process.env.HTTPS_PROXY; + const originalNoProxy = process.env.NO_PROXY; + const originalLowerHttpsProxy = process.env.https_proxy; + const originalLowerNoProxy = process.env.no_proxy; + + delete process.env.PROXY; + delete process.env.https_proxy; + delete process.env.no_proxy; + process.env.HTTPS_PROXY = 'http://https-proxy.example.com:8080'; + process.env.NO_PROXY = 'mcp.example.com'; + + try { + conn = new MCPConnection({ + serverName: 'customfetch-standard-env-no-proxy', + serverConfig: { + type: 'streamable-http', + url: 'https://mcp.example.com/mcp', + }, + useSSRFProtection: false, + }); + + const privateSelf = conn as unknown as { + agents: Array<{ constructor: { name: string } }>; + createFetchFunction: ( + getHeaders: () => Record | null | undefined, + timeout?: number, + sseBodyTimeout?: number, + configuredSecretHeaderKeys?: ReadonlySet, + baseUrl?: string, + ) => CustomFetch; + }; + privateSelf.createFetchFunction.call( + conn, + () => null, + undefined, + 300000, + undefined, + 'https://mcp.example.com/mcp', + ); + + expect(privateSelf.agents.map((agent) => agent.constructor.name)).toEqual(['Agent', 'Agent']); + } finally { + if (originalProxy == null) { + delete process.env.PROXY; + } else { + process.env.PROXY = originalProxy; + } + if (originalHttpsProxy == null) { + delete process.env.HTTPS_PROXY; + } else { + process.env.HTTPS_PROXY = originalHttpsProxy; + } + if (originalNoProxy == null) { + delete process.env.NO_PROXY; + } else { + process.env.NO_PROXY = originalNoProxy; + } + if (originalLowerHttpsProxy == null) { + delete process.env.https_proxy; + } else { + process.env.https_proxy = originalLowerHttpsProxy; + } + if (originalLowerNoProxy == null) { + delete process.env.no_proxy; + } else { + process.env.no_proxy = originalLowerNoProxy; + } + } + }); + + it('should preflight proxied targets before dispatching network requests', async () => { + mockedResolveHostnameSSRF.mockResolvedValueOnce(true); + + conn = new MCPConnection({ + serverName: 'customfetch-proxy-ssrf', + serverConfig: { + type: 'streamable-http', + url: 'https://mcp.example.com/mcp', + proxy: 'http://proxy.example.com:8080', + }, + useSSRFProtection: true, + }); + + const customFetch = getCustomFetch(conn); + + await expect(customFetch('http://blocked.example.com/mcp')).rejects.toThrow( + /proxied MCP request target/, + ); + expect(mockedResolveHostnameSSRF).toHaveBeenCalledWith('blocked.example.com', null, '80'); + }); + it.each<['string' | 'URL' | 'Request']>([['string'], ['URL'], ['Request']])( 'should accept a %s input without throwing on URL derivation', async (shape) => { diff --git a/packages/api/src/mcp/__tests__/mcp.spec.ts b/packages/api/src/mcp/__tests__/mcp.spec.ts index d5cc44569f..de2839371d 100644 --- a/packages/api/src/mcp/__tests__/mcp.spec.ts +++ b/packages/api/src/mcp/__tests__/mcp.spec.ts @@ -150,6 +150,18 @@ describe('Environment Variable Extraction (MCP)', () => { expect(result.headers).toEqual(options.headers); }); + it('should validate proxy URLs for remote HTTP transports', () => { + const options = { + type: 'streamable-http', + url: 'https://example.com/api', + proxy: 'http://proxy.example.com:8080', + }; + + const result = StreamableHTTPOptionsSchema.parse(options); + + expect(result.proxy).toBe('http://proxy.example.com:8080'); + }); + it('should accept "http" as an alias for "streamable-http"', () => { const options = { type: 'http', @@ -324,6 +336,20 @@ describe('Environment Variable Extraction (MCP)', () => { }); }); + it('should process proxy in streamable-http options', () => { + process.env.MCP_PROXY_URL = 'http://proxy.example.com:8080'; + const options: MCPOptions = { + type: 'streamable-http', + url: 'https://example.com', + proxy: '${MCP_PROXY_URL}', + }; + + const result = processMCPEnv({ options }); + + expect('proxy' in result && result.proxy).toBe('http://proxy.example.com:8080'); + delete process.env.MCP_PROXY_URL; + }); + it('should maintain streamable-http type in processed options', () => { const options: MCPOptions = { type: 'streamable-http', diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index 9f779d7e87..6945b0c5eb 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -1,6 +1,6 @@ import { EventEmitter } from 'events'; import { logger } from '@librechat/data-schemas'; -import { fetch as undiciFetch, Agent } from 'undici'; +import { fetch as undiciFetch, Agent, ProxyAgent } from 'undici'; import { StdioClientTransport, getDefaultEnvironment, @@ -15,16 +15,30 @@ import type { RequestInit as UndiciRequestInit, RequestInfo as UndiciRequestInfo, Response as UndiciResponse, + Dispatcher, } from 'undici'; import type { MCPOAuthTokens } from './oauth/types'; import type * as t from './types'; -import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth'; +import { createSSRFSafeUndiciConnect, isSSRFTarget, resolveHostnameSSRF } from '~/auth'; import { runOutsideTracing } from '~/utils/tracing'; import { sanitizeUrlForLogging } from './utils'; import { withTimeout } from '~/utils/promise'; import { mcpConfig } from './mcpConfig'; type FetchLike = (url: string | URL, init?: RequestInit) => Promise; +type ManagedDispatcher = Agent | ProxyAgent; + +type MCPProxyConfig = + | { + type: 'explicit'; + proxyUrl: string; + } + | { + type: 'env'; + httpProxy?: string; + httpsProxy?: string; + noProxy?: string; + }; function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions { return 'command' in options; @@ -164,7 +178,7 @@ function normalizeInitHeaders(init: UndiciRequestInit | undefined): Record | null | undefined, ): UndiciRequestInit { const hasInitHeaders = init?.headers != null; @@ -196,6 +210,166 @@ function getUrlPort(url: URL | string): string { return ''; } +function getTrimmedEnv(...keys: string[]): string | undefined { + for (const key of keys) { + const value = process.env[key]?.trim(); + if (value) { + return value; + } + } + return undefined; +} + +function getMCPProxyConfig(options: t.MCPOptions): MCPProxyConfig | undefined { + const configuredProxy = + 'proxy' in options && typeof options.proxy === 'string' ? options.proxy.trim() : ''; + if (configuredProxy) { + return { type: 'explicit', proxyUrl: configuredProxy }; + } + + const libreChatProxy = process.env.PROXY?.trim() ?? ''; + if (libreChatProxy) { + return { type: 'explicit', proxyUrl: libreChatProxy }; + } + + const httpProxy = getTrimmedEnv('http_proxy', 'HTTP_PROXY'); + const httpsProxy = getTrimmedEnv('https_proxy', 'HTTPS_PROXY'); + if (!httpProxy && !httpsProxy) { + return undefined; + } + + return { + type: 'env', + httpProxy, + httpsProxy, + noProxy: getTrimmedEnv('no_proxy', 'NO_PROXY'), + }; +} + +function getProxyEntryPort(entry: string): { hostname: string; port: number } { + const parsed = entry.match(/^(.+):(\d+)$/); + return { + hostname: (parsed ? parsed[1] : entry) + .replace(/^\*?\./, '') + .replace(/^\[|\]$/g, '') + .toLowerCase(), + port: parsed ? Number.parseInt(parsed[2], 10) : 0, + }; +} + +function shouldBypassEnvProxy(url: URL, noProxy?: string): boolean { + if (!noProxy) { + return false; + } + + const trimmed = noProxy.trim(); + if (!trimmed) { + return false; + } + if (trimmed === '*') { + return true; + } + + const hostname = url.hostname.replace(/^\[|\]$/g, '').toLowerCase(); + const port = Number.parseInt(getUrlPort(url), 10) || 0; + + for (const entry of trimmed.split(/[,\s]/)) { + if (!entry) { + continue; + } + + const proxyEntry = getProxyEntryPort(entry); + if (proxyEntry.port && proxyEntry.port !== port) { + continue; + } + if (hostname === proxyEntry.hostname || hostname.endsWith(`.${proxyEntry.hostname}`)) { + return true; + } + } + + return false; +} + +function getProxyUrlForRequest( + proxyConfig: MCPProxyConfig | undefined, + urlString: string, +): string | undefined { + if (!proxyConfig || !urlString) { + return undefined; + } + if (proxyConfig.type === 'explicit') { + return proxyConfig.proxyUrl; + } + + const url = new URL(urlString); + if (shouldBypassEnvProxy(url, proxyConfig.noProxy)) { + return undefined; + } + if (url.protocol === 'https:') { + return proxyConfig.httpsProxy ?? proxyConfig.httpProxy; + } + if (url.protocol === 'http:') { + return proxyConfig.httpProxy; + } + return undefined; +} + +function createMCPDispatcher(options: { + bodyTimeout: number; + headersTimeout: number; + proxyUrl?: string; + keepAliveTimeout?: number; + keepAliveMaxTimeout?: number; + connect?: ReturnType; +}): ManagedDispatcher { + const { bodyTimeout, headersTimeout, proxyUrl, keepAliveTimeout, keepAliveMaxTimeout, connect } = + options; + + const baseOptions = { + bodyTimeout, + headersTimeout, + ...(keepAliveTimeout != null ? { keepAliveTimeout } : {}), + ...(keepAliveMaxTimeout != null ? { keepAliveMaxTimeout } : {}), + }; + + if (proxyUrl) { + return new ProxyAgent({ + uri: proxyUrl, + ...baseOptions, + }); + } + + return new Agent({ + ...baseOptions, + ...(connect != null ? { connect } : {}), + }); +} + +async function assertProxiedRequestTargetAllowed( + urlString: string, + proxyConfig: MCPProxyConfig | undefined, + useSSRFProtection: boolean, + allowedAddresses?: string[] | null, +): Promise { + if (!proxyConfig || !useSSRFProtection) { + return; + } + + const targetUrl = new URL(urlString); + const port = getUrlPort(targetUrl); + const isBlockedTarget = + isSSRFTarget(targetUrl.hostname, allowedAddresses, port) || + (await resolveHostnameSSRF(targetUrl.hostname, allowedAddresses, port)); + + if (!isBlockedTarget) { + return; + } + + throw new Error( + `SSRF protection: proxied MCP request target "${targetUrl.hostname}" resolved to a private/reserved address`, + ); +} + /** * Drops credential-bearing headers when a 307/308 redirect crosses an origin * boundary. Removes the always-forbidden set plus any caller-supplied secret @@ -413,7 +587,7 @@ export class MCPConnection extends EventEmitter { private isReconnecting = false; private isInitializing = false; private reconnectAttempts = 0; - private agents: Agent[] = []; + private agents: Dispatcher[] = []; private readonly userId?: string; private lastPingTime: number; private lastConnectionCheckAt: number = 0; @@ -423,6 +597,7 @@ export class MCPConnection extends EventEmitter { private oauthRecovery = false; private readonly useSSRFProtection: boolean; private readonly allowedAddresses?: string[] | null; + private readonly proxyConfig?: MCPProxyConfig; iconPath?: string; timeout?: number; sseReadTimeout?: number; @@ -538,6 +713,7 @@ export class MCPConnection extends EventEmitter { this.userId = params.userId; this.useSSRFProtection = params.useSSRFProtection === true; this.allowedAddresses = params.allowedAddresses ?? null; + this.proxyConfig = getMCPProxyConfig(params.serverConfig); this.iconPath = params.serverConfig.iconPath; this.timeout = params.serverConfig.timeout; this.sseReadTimeout = params.serverConfig.sseReadTimeout; @@ -580,57 +756,69 @@ export class MCPConnection extends EventEmitter { configuredSecretHeaderKeys?: ReadonlySet, baseUrl?: string, ): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise { + const proxyConfig = this.proxyConfig; + const initialProxyUrl = baseUrl ? getProxyUrlForRequest(proxyConfig, baseUrl) : undefined; const basePort = baseUrl ? getUrlPort(baseUrl) : ''; - const ssrfConnect = this.useSSRFProtection - ? createSSRFSafeUndiciConnect(this.allowedAddresses, basePort) - : undefined; + const ssrfConnect = + this.useSSRFProtection && !initialProxyUrl + ? createSSRFSafeUndiciConnect(this.allowedAddresses, basePort) + : undefined; const connectOpts = ssrfConnect != null ? { connect: ssrfConnect } : {}; + const useSSRFProtection = this.useSSRFProtection; + const allowedAddresses = this.allowedAddresses; /** Capture only the fields needed by the fetch closure; see factory note above. */ const agents = this.agents; const effectiveTimeout = timeout || DEFAULT_TIMEOUT; - const postAgent = new Agent({ + const postAgent = createMCPDispatcher({ bodyTimeout: effectiveTimeout, headersTimeout: effectiveTimeout, + proxyUrl: initialProxyUrl, ...connectOpts, }); this.agents.push(postAgent); - let getAgent: Agent | undefined; + let getAgent: ManagedDispatcher | undefined; if (sseBodyTimeout != null) { - getAgent = new Agent({ + getAgent = createMCPDispatcher({ bodyTimeout: sseBodyTimeout, headersTimeout: effectiveTimeout, + proxyUrl: initialProxyUrl, ...connectOpts, }); this.agents.push(getAgent); } - let safeRedirectPostAgent: Agent | undefined; - let safeRedirectGetAgent: Agent | undefined; + const safeRedirectAgents = new Map(); /** * Allowlist mode keeps the original MCP URL admin-approved, but redirect * targets are server-controlled. These agents add connect-time DNS checks * for those cross-origin hops so DNS rebinding cannot beat the standalone * resolveHostnameSSRF pre-check. */ - const createSafeRedirectAgent = (bodyTimeout: number): Agent => { - const redirectSSRFConnect = createSSRFSafeUndiciConnect(); - const agent = new Agent({ + const getSafeRedirectDispatcher = ( + isGetRequest: boolean, + targetUrlString: string, + ): ManagedDispatcher => { + const bodyTimeout = + isGetRequest && sseBodyTimeout != null ? sseBodyTimeout : effectiveTimeout; + const redirectProxyUrl = getProxyUrlForRequest(proxyConfig, targetUrlString); + const key = `${bodyTimeout}:${redirectProxyUrl ?? 'direct'}`; + const existingAgent = safeRedirectAgents.get(key); + if (existingAgent) { + return existingAgent; + } + + const redirectSSRFConnect = redirectProxyUrl ? undefined : createSSRFSafeUndiciConnect(); + const agent = createMCPDispatcher({ bodyTimeout, headersTimeout: effectiveTimeout, - connect: redirectSSRFConnect, + proxyUrl: redirectProxyUrl, + ...(redirectSSRFConnect != null ? { connect: redirectSSRFConnect } : {}), }); + safeRedirectAgents.set(key, agent); agents.push(agent); return agent; }; - const getSafeRedirectDispatcher = (isGetRequest: boolean): Agent => { - if (!isGetRequest || sseBodyTimeout == null) { - safeRedirectPostAgent ??= createSafeRedirectAgent(effectiveTimeout); - return safeRedirectPostAgent; - } - safeRedirectGetAgent ??= createSafeRedirectAgent(sseBodyTimeout); - return safeRedirectGetAgent; - }; return async function customFetch( input: UndiciRequestInfo, @@ -663,9 +851,16 @@ export class MCPConnection extends EventEmitter { let currentInit = buildFetchInit(resolvedInit, dispatcher, requestHeaders); let currentUrlString = urlString; + let currentAllowedAddresses = allowedAddresses; const originalOrigin = new URL(currentUrlString).origin; for (let redirects = 0; ; redirects++) { + await assertProxiedRequestTargetAllowed( + currentUrlString, + proxyConfig, + useSSRFProtection, + currentAllowedAddresses, + ); const response = await undiciFetch(currentUrlString, currentInit); const isMethodPreservingRedirect = response.status === 307 || response.status === 308; @@ -695,7 +890,7 @@ export class MCPConnection extends EventEmitter { * design — letting redirect targets inherit the exemption would open * an SSRF amplification primitive. */ - if (await resolveHostnameSSRF(targetUrl.hostname)) { + if (isSSRFTarget(targetUrl.hostname) || (await resolveHostnameSSRF(targetUrl.hostname))) { logger.warn( `[MCP] Blocked redirect to private/reserved address: ${sanitizeUrlForLogging(targetUrl)}`, ); @@ -715,6 +910,7 @@ export class MCPConnection extends EventEmitter { } if (isCrossOriginRedirect) { + currentAllowedAddresses = null; /** * Once a server-controlled cross-origin hop is seen, keep the safe * dispatcher for the rest of this redirect chain. Restoring the @@ -725,7 +921,7 @@ export class MCPConnection extends EventEmitter { */ currentInit = { ...currentInit, - dispatcher: getSafeRedirectDispatcher(isGet), + dispatcher: getSafeRedirectDispatcher(isGet, targetUrl.href), }; } @@ -821,14 +1017,17 @@ export class MCPConnection extends EventEmitter { * The connect timeout is extended because proxies may delay initial response. */ const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT; - const ssrfConnect = this.useSSRFProtection - ? createSSRFSafeUndiciConnect(this.allowedAddresses, getUrlPort(url)) - : undefined; - const sseAgent = new Agent({ + const sseProxyUrl = getProxyUrlForRequest(this.proxyConfig, options.url); + const ssrfConnect = + this.useSSRFProtection && !sseProxyUrl + ? createSSRFSafeUndiciConnect(this.allowedAddresses, getUrlPort(url)) + : undefined; + const sseAgent = createMCPDispatcher({ bodyTimeout: sseTimeout, headersTimeout: sseTimeout, keepAliveTimeout: sseTimeout, keepAliveMaxTimeout: sseTimeout * 2, + proxyUrl: sseProxyUrl, ...(ssrfConnect != null ? { connect: ssrfConnect } : {}), }); this.agents.push(sseAgent); @@ -842,13 +1041,23 @@ export class MCPConnection extends EventEmitter { signal: abortController.signal, }, eventSourceInit: { - fetch: (url, init) => { + fetch: async (url, init) => { + const { urlString, resolvedInit } = await resolveFetchInput( + url as UndiciRequestInfo, + init as UndiciRequestInit, + ); + await assertProxiedRequestTargetAllowed( + urlString, + this.proxyConfig, + this.useSSRFProtection, + this.allowedAddresses, + ); /** Merge headers: SSE defaults < init headers < user headers (user wins) */ const fetchHeaders = new Headers( - Object.assign({}, SSE_REQUEST_HEADERS, init?.headers, headers), + Object.assign({}, SSE_REQUEST_HEADERS, resolvedInit?.headers, headers), ); - return undiciFetch(url, { - ...init, + return undiciFetch(urlString, { + ...resolvedInit, redirect: 'manual', dispatcher: sseAgent, headers: fetchHeaders, diff --git a/packages/api/src/utils/env.spec.ts b/packages/api/src/utils/env.spec.ts index e1244fa605..8a4cdefb4a 100644 --- a/packages/api/src/utils/env.spec.ts +++ b/packages/api/src/utils/env.spec.ts @@ -990,6 +990,7 @@ describe('processMCPEnv', () => { process.env.OAUTH_CLIENT_ID = 'oauth-client-id-value'; process.env.OAUTH_CLIENT_SECRET = 'oauth-client-secret-value'; process.env.MCP_SERVER_URL = 'https://mcp.example.com'; + process.env.MCP_PROXY_URL = 'http://proxy.example.com:8080'; }); afterEach(() => { @@ -998,6 +999,7 @@ describe('processMCPEnv', () => { delete process.env.OAUTH_CLIENT_ID; delete process.env.OAUTH_CLIENT_SECRET; delete process.env.MCP_SERVER_URL; + delete process.env.MCP_PROXY_URL; }); it('should return null/undefined as-is', () => { @@ -1045,6 +1047,22 @@ describe('processMCPEnv', () => { }); }); + it('should process outbound proxy for remote MCP options', () => { + const options: MCPOptions = { + type: 'sse', + url: '${MCP_SERVER_URL}/sse', + proxy: '${MCP_PROXY_URL}', + }; + + const result = processMCPEnv({ options }); + + expect(result).toEqual({ + type: 'sse', + url: 'https://mcp.example.com/sse', + proxy: 'http://proxy.example.com:8080', + }); + }); + it('should process OAuth configuration with environment variables', () => { const options: MCPOptions = { type: 'streamable-http', diff --git a/packages/api/src/utils/env.ts b/packages/api/src/utils/env.ts index b5220f1ae8..90da393a53 100644 --- a/packages/api/src/utils/env.ts +++ b/packages/api/src/utils/env.ts @@ -383,6 +383,17 @@ export function processMCPEnv(params: { }); } + // Process outbound proxy if it exists (for SSE and StreamableHTTP types) + if ('proxy' in newObj && newObj.proxy) { + newObj.proxy = processSingleValue({ + user, + body, + dbSourced, + customUserVars, + originalValue: newObj.proxy, + }); + } + // Process OAuth configuration if it exists (for all transport types) if ('oauth' in newObj && newObj.oauth) { const processedOAuth: Record = {}; diff --git a/packages/data-provider/specs/mcp.spec.ts b/packages/data-provider/specs/mcp.spec.ts index 573769c4fa..2bdb0fb5c3 100644 --- a/packages/data-provider/specs/mcp.spec.ts +++ b/packages/data-provider/specs/mcp.spec.ts @@ -1,4 +1,8 @@ -import { SSEOptionsSchema, MCPServerUserInputSchema } from '../src/mcp'; +import { + SSEOptionsSchema, + StreamableHTTPOptionsSchema, + MCPServerUserInputSchema, +} from '../src/mcp'; describe('MCPServerUserInputSchema', () => { describe('env variable exfiltration prevention', () => { @@ -52,6 +56,38 @@ describe('MCPServerUserInputSchema', () => { }); }); + describe('proxy field restrictions', () => { + it('should accept admin-configured proxies for streamable-http', () => { + const result = StreamableHTTPOptionsSchema.safeParse({ + type: 'streamable-http', + url: 'https://mcp-server.com/http', + proxy: 'http://proxy.example.com:8080', + }); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.proxy).toBe('http://proxy.example.com:8080'); + } + }); + + it('should reject unsupported proxy protocols', () => { + const result = StreamableHTTPOptionsSchema.safeParse({ + type: 'streamable-http', + url: 'https://mcp-server.com/http', + proxy: 'ftp://proxy.example.com', + }); + expect(result.success).toBe(false); + }); + + it('should reject proxy configuration from user input', () => { + const result = MCPServerUserInputSchema.safeParse({ + type: 'streamable-http', + url: 'https://mcp-server.com/http', + proxy: 'http://proxy.example.com:8080', + }); + expect(result.success).toBe(false); + }); + }); + describe('protocol allowlisting', () => { it('should reject file:// URLs for SSE', () => { const result = MCPServerUserInputSchema.safeParse({ diff --git a/packages/data-provider/src/mcp.ts b/packages/data-provider/src/mcp.ts index b22a599b9b..4236d49096 100644 --- a/packages/data-provider/src/mcp.ts +++ b/packages/data-provider/src/mcp.ts @@ -103,6 +103,25 @@ const BaseOptionsSchema = z.object({ .optional(), }); +const ProxyUrlSchema = z + .string() + .transform((val: string) => extractEnvVariable(val)) + .pipe(z.string().url()) + .refine( + (val: string) => { + const protocol = new URL(val).protocol; + return ( + protocol === 'http:' || + protocol === 'https:' || + protocol === 'socks:' || + protocol === 'socks5:' + ); + }, + { + message: 'Proxy URL must use http://, https://, socks://, or socks5://', + }, + ); + export const StdioOptionsSchema = BaseOptionsSchema.extend({ type: z.literal('stdio').default('stdio'), /** @@ -163,6 +182,8 @@ export const WebSocketOptionsSchema = BaseOptionsSchema.extend({ export const SSEOptionsSchema = BaseOptionsSchema.extend({ type: z.literal('sse').default('sse'), headers: z.record(z.string(), z.string()).optional(), + /** Optional outbound proxy URL for this remote MCP transport */ + proxy: ProxyUrlSchema.optional(), url: z .string() .transform((val: string) => extractEnvVariable(val)) @@ -181,6 +202,8 @@ export const SSEOptionsSchema = BaseOptionsSchema.extend({ export const StreamableHTTPOptionsSchema = BaseOptionsSchema.extend({ type: z.union([z.literal('streamable-http'), z.literal('http')]), headers: z.record(z.string(), z.string()).optional(), + /** Optional outbound proxy URL for this remote MCP transport */ + proxy: ProxyUrlSchema.optional(), url: z .string() .transform((val: string) => extractEnvVariable(val)) @@ -261,9 +284,11 @@ export const MCPServerUserInputSchema = z.union([ url: userUrlSchema(isWsProtocol, 'WebSocket URL must use ws:// or wss://'), }), omitServerManagedFields(SSEOptionsSchema).extend({ + proxy: z.never().optional(), url: userUrlSchema(isHttpProtocol, 'SSE URL must use http:// or https://'), }), omitServerManagedFields(StreamableHTTPOptionsSchema).extend({ + proxy: z.never().optional(), url: userUrlSchema(isHttpProtocol, 'Streamable HTTP URL must use http:// or https://'), }), ]);