From 6b5596ec3645a70aaec9de6d6a884dab16ece23c Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 12 May 2026 13:26:05 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=8D=AA=20refactor:=20Refresh=20CloudFront?= =?UTF-8?q?=20Media=20Cookies=20(#13091)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: refresh CloudFront media cookies * fix: satisfy changed-file lint * fix: centralize CloudFront image retry * fix: honor base path for CloudFront refresh * fix: bypass auth refresh for CloudFront cookie retry * fix: pass app auth header to CloudFront retry * test: cover CloudFront refresh with OpenID reuse * fix: avoid duplicate CloudFront refresh retries * fix: clear CloudFront scope cookie with matching flags --- .../__tests__/requireJwtAuth.spec.js | 27 +- api/server/middleware/requireJwtAuth.js | 17 +- api/server/routes/__tests__/config.spec.js | 53 +++ api/server/routes/auth.cloudfront.test.js | 154 +++++++++ api/server/routes/auth.js | 23 +- api/server/routes/config.js | 25 +- .../Config/__tests__/useAppStartup.spec.tsx | 39 +++ client/src/hooks/Config/useAppStartup.ts | 12 +- .../cdn/__tests__/cloudfront-cookies.test.ts | 230 ++++++++++++- packages/api/src/cdn/cloudfront-cookies.ts | 318 +++++++++++++++++- packages/client/src/utils/cloudfront.spec.tsx | 202 +++++++++++ packages/client/src/utils/cloudfront.ts | 209 ++++++++++++ packages/client/src/utils/index.ts | 1 + .../specs/headers-helpers.spec.ts | 4 +- packages/data-provider/src/config.ts | 6 + packages/data-provider/src/headers-helpers.ts | 5 + 16 files changed, 1294 insertions(+), 31 deletions(-) create mode 100644 api/server/routes/auth.cloudfront.test.js create mode 100644 packages/client/src/utils/cloudfront.spec.tsx create mode 100644 packages/client/src/utils/cloudfront.ts diff --git a/api/server/middleware/__tests__/requireJwtAuth.spec.js b/api/server/middleware/__tests__/requireJwtAuth.spec.js index 317e8ae6b4..7f0963398d 100644 --- a/api/server/middleware/__tests__/requireJwtAuth.spec.js +++ b/api/server/middleware/__tests__/requireJwtAuth.spec.js @@ -53,6 +53,7 @@ jest.mock('@librechat/api', () => { const { tenantStorage } = require('@librechat/data-schemas'); return { isEnabled: jest.fn(() => false), + maybeRefreshCloudFrontAuthCookiesMiddleware: jest.fn((req, res, next) => next()), tenantContextMiddleware: (req, res, next) => { const tenantId = req.user?.tenantId; if (!tenantId) { @@ -67,7 +68,7 @@ jest.mock('@librechat/api', () => { const requireJwtAuth = require('../requireJwtAuth'); const { getTenantId } = require('@librechat/data-schemas'); -const { isEnabled } = require('@librechat/api'); +const { isEnabled, maybeRefreshCloudFrontAuthCookiesMiddleware } = require('@librechat/api'); const passport = require('passport'); const jwtSecret = 'test-refresh-secret'; @@ -108,6 +109,7 @@ describe('requireJwtAuth tenant context chaining', () => { mockPassportError = null; mockRegisteredStrategies = new Set(['jwt']); isEnabled.mockReturnValue(false); + maybeRefreshCloudFrontAuthCookiesMiddleware.mockClear(); passport.authenticate.mockClear(); passport._strategy.mockClear(); if (originalJwtSecret === undefined) { @@ -134,6 +136,21 @@ describe('requireJwtAuth tenant context chaining', () => { expect(tenantId).toBe('tenant-abc'); }); + it('refreshes CloudFront auth cookies after passport auth succeeds', () => { + const req = mockReq({ tenantId: 'tenant-abc', role: 'user' }); + const res = mockRes(); + const next = jest.fn(); + + requireJwtAuth(req, res, next); + + expect(maybeRefreshCloudFrontAuthCookiesMiddleware).toHaveBeenCalledWith( + req, + res, + expect.any(Function), + ); + expect(next).toHaveBeenCalled(); + }); + it('ALS tenant context is NOT set when user has no tenantId', async () => { const tenantId = await runAuth({ role: 'user' }); expect(tenantId).toBeUndefined(); @@ -201,6 +218,11 @@ describe('requireJwtAuth tenant context chaining', () => { { session: false }, expect.any(Function), ); + expect(maybeRefreshCloudFrontAuthCookiesMiddleware).toHaveBeenCalledWith( + req, + res, + expect.any(Function), + ); }); it('does not authenticate OpenID JWT when the reuse cookie belongs to another user', () => { @@ -236,6 +258,7 @@ describe('requireJwtAuth tenant context chaining', () => { { session: false }, expect.any(Function), ); + expect(maybeRefreshCloudFrontAuthCookiesMiddleware).not.toHaveBeenCalled(); }); it('does not use OpenID JWT when the signed OpenID reuse cookie is missing', () => { @@ -262,6 +285,7 @@ describe('requireJwtAuth tenant context chaining', () => { { session: false }, expect.any(Function), ); + expect(maybeRefreshCloudFrontAuthCookiesMiddleware).not.toHaveBeenCalled(); }); it('does not use OpenID JWT when the OpenID reuse cookie is invalid', () => { @@ -288,6 +312,7 @@ describe('requireJwtAuth tenant context chaining', () => { { session: false }, expect.any(Function), ); + expect(maybeRefreshCloudFrontAuthCookiesMiddleware).not.toHaveBeenCalled(); }); it('skips OpenID JWT fallback when the strategy was not registered', async () => { diff --git a/api/server/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js index 2820de5359..e9abbc7fa8 100644 --- a/api/server/middleware/requireJwtAuth.js +++ b/api/server/middleware/requireJwtAuth.js @@ -1,7 +1,11 @@ const cookies = require('cookie'); const jwt = require('jsonwebtoken'); const passport = require('passport'); -const { isEnabled, tenantContextMiddleware } = require('@librechat/api'); +const { + isEnabled, + tenantContextMiddleware, + maybeRefreshCloudFrontAuthCookiesMiddleware, +} = require('@librechat/api'); const hasPassportStrategy = (strategy) => typeof passport._strategy === 'function' && passport._strategy(strategy) != null; @@ -23,6 +27,8 @@ const getValidOpenIdReuseUserId = (parsedCookies) => { }; const getAuthenticatedUserId = (user) => user?.id?.toString?.() ?? user?._id?.toString?.(); +const refreshCloudFrontCookies = + maybeRefreshCloudFrontAuthCookiesMiddleware ?? ((_req, _res, next) => next()); /** * Custom Middleware to handle JWT authentication, with support for OpenID token reuse. @@ -65,8 +71,13 @@ const requireJwtAuth = (req, res, next) => { } req.user = user; req.authStrategy = strategy; - // req.user is now populated by passport — set up tenant ALS context - tenantContextMiddleware(req, res, next); + refreshCloudFrontCookies(req, res, (refreshErr) => { + if (refreshErr) { + return next(refreshErr); + } + // req.user is now populated by passport — set up tenant ALS context + tenantContextMiddleware(req, res, next); + }); })(req, res, next); }; diff --git a/api/server/routes/__tests__/config.spec.js b/api/server/routes/__tests__/config.spec.js index 6acd87ef22..d92c56b8bb 100644 --- a/api/server/routes/__tests__/config.spec.js +++ b/api/server/routes/__tests__/config.spec.js @@ -20,6 +20,12 @@ jest.mock('@librechat/data-schemas', () => ({ getTenantId: (...args) => mockGetTenantId(...args), })); +const mockGetCloudFrontConfig = jest.fn(() => null); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + getCloudFrontConfig: (...args) => mockGetCloudFrontConfig(...args), +})); + const request = require('supertest'); const express = require('express'); const configRoute = require('../config'); @@ -187,6 +193,53 @@ describe('GET /api/config', () => { expect(response.body).toHaveProperty('serverDomain'); }); + it('should advertise CloudFront cookie refresh only when signed-cookie mode is active', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetCloudFrontConfig.mockReturnValue({ + domain: 'https://cdn.example.com', + imageSigning: 'cookies', + cookieDomain: '.example.com', + privateKey: 'test-private-key', + keyPairId: 'K123ABC', + }); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body.cloudFront).toEqual({ + cookieRefresh: { + endpoint: '/api/auth/cloudfront/refresh', + domain: 'https://cdn.example.com', + }, + }); + }); + + it('should omit CloudFront cookie refresh when signed-cookie mode is inactive', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetCloudFrontConfig.mockReturnValue({ + domain: 'https://cdn.example.com', + imageSigning: 'url', + }); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body).not.toHaveProperty('cloudFront'); + }); + + it('should omit CloudFront cookie refresh when cookie mode cannot mint cookies', async () => { + mockGetAppConfig.mockResolvedValue(baseAppConfig); + mockGetCloudFrontConfig.mockReturnValue({ + domain: 'https://cdn.example.com', + imageSigning: 'cookies', + }); + const app = createApp(null); + + const response = await request(app).get('/api/config'); + + expect(response.body).not.toHaveProperty('cloudFront'); + }); + it('should default allowAccountDeletion to true when env var is unset', async () => { mockGetAppConfig.mockResolvedValue(baseAppConfig); const app = createApp(null); diff --git a/api/server/routes/auth.cloudfront.test.js b/api/server/routes/auth.cloudfront.test.js new file mode 100644 index 0000000000..9d50ac97a7 --- /dev/null +++ b/api/server/routes/auth.cloudfront.test.js @@ -0,0 +1,154 @@ +const express = require('express'); +const request = require('supertest'); + +const mockForceRefreshCloudFrontAuthCookies = jest.fn(); + +jest.mock('@librechat/api', () => ({ + createSetBalanceConfig: jest.fn(() => (req, res, next) => next()), + forceRefreshCloudFrontAuthCookies: (...args) => mockForceRefreshCloudFrontAuthCookies(...args), +})); + +jest.mock('~/server/controllers/AuthController', () => ({ + refreshController: jest.fn((req, res) => res.status(200).end()), + registrationController: jest.fn((req, res) => res.status(200).end()), + resetPasswordController: jest.fn((req, res) => res.status(200).end()), + resetPasswordRequestController: jest.fn((req, res) => res.status(200).end()), + graphTokenController: jest.fn((req, res) => res.status(200).end()), +})); + +jest.mock('~/server/controllers/TwoFactorController', () => ({ + enable2FA: jest.fn((req, res) => res.status(200).end()), + verify2FA: jest.fn((req, res) => res.status(200).end()), + confirm2FA: jest.fn((req, res) => res.status(200).end()), + disable2FA: jest.fn((req, res) => res.status(200).end()), + regenerateBackupCodes: jest.fn((req, res) => res.status(200).end()), +})); + +jest.mock('~/server/controllers/auth/TwoFactorAuthController', () => ({ + verify2FAWithTempToken: jest.fn((req, res) => res.status(200).end()), +})); + +jest.mock('~/server/controllers/auth/LogoutController', () => ({ + logoutController: jest.fn((req, res) => res.status(200).end()), +})); + +jest.mock('~/server/controllers/auth/LoginController', () => ({ + loginController: jest.fn((req, res) => res.status(200).end()), +})); + +jest.mock('~/models', () => ({ + findBalanceByUser: jest.fn(), + upsertBalanceFields: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn(), +})); + +jest.mock('~/server/middleware', () => { + const pass = (req, res, next) => next(); + return { + logHeaders: pass, + loginLimiter: pass, + checkBan: pass, + requireLocalAuth: pass, + requireLdapAuth: pass, + registerLimiter: pass, + checkInviteUser: pass, + validateRegistration: pass, + resetPasswordLimiter: pass, + validatePasswordReset: pass, + requireJwtAuth: jest.fn((req, res, next) => { + if (req.headers.authorization !== 'Bearer ok') { + return res.status(401).json({ message: 'Unauthorized' }); + } + req.user = { _id: 'user123', tenantId: 'tenantA' }; + if (req.headers['x-cloudfront-warmed'] === 'true') { + req.cloudFrontAuthCookieRefreshResult = { + enabled: true, + attempted: true, + refreshed: true, + expiresInSec: 1800, + refreshAfterSec: 1500, + }; + } + return next(); + }), + }; +}); + +const authRouter = require('./auth'); + +describe('POST /api/auth/cloudfront/refresh', () => { + let app; + + beforeEach(() => { + jest.clearAllMocks(); + app = express(); + app.use(express.json()); + app.use('/api/auth', authRouter); + }); + + it('requires authentication', async () => { + await request(app).post('/api/auth/cloudfront/refresh').expect(401); + + expect(mockForceRefreshCloudFrontAuthCookies).not.toHaveBeenCalled(); + }); + + it('returns 404 when CloudFront cookie mode is disabled', async () => { + mockForceRefreshCloudFrontAuthCookies.mockReturnValue({ + enabled: false, + attempted: false, + refreshed: false, + reason: 'cloudfront_disabled', + }); + + const response = await request(app) + .post('/api/auth/cloudfront/refresh') + .set('Authorization', 'Bearer ok') + .expect(404); + + expect(response.status).toBe(404); + }); + + it('returns cookie refresh timing when CloudFront cookies are refreshed', async () => { + mockForceRefreshCloudFrontAuthCookies.mockReturnValue({ + enabled: true, + attempted: true, + refreshed: true, + expiresInSec: 1800, + refreshAfterSec: 1500, + }); + + const response = await request(app) + .post('/api/auth/cloudfront/refresh') + .set('Authorization', 'Bearer ok') + .expect(200); + + expect(response.body).toEqual({ + ok: true, + expiresInSec: 1800, + refreshAfterSec: 1500, + }); + expect(mockForceRefreshCloudFrontAuthCookies).toHaveBeenCalledWith( + expect.objectContaining({ user: { _id: 'user123', tenantId: 'tenantA' } }), + expect.any(Object), + { _id: 'user123', tenantId: 'tenantA' }, + ); + }); + + it('reuses the auth middleware refresh result instead of minting cookies twice', async () => { + const response = await request(app) + .post('/api/auth/cloudfront/refresh') + .set('Authorization', 'Bearer ok') + .set('x-cloudfront-warmed', 'true') + .expect(200); + + expect(response.body).toEqual({ + ok: true, + expiresInSec: 1800, + refreshAfterSec: 1500, + }); + expect(mockForceRefreshCloudFrontAuthCookies).not.toHaveBeenCalled(); + }); +}); diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index c660e6f99d..e2fc08187d 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -1,5 +1,5 @@ const express = require('express'); -const { createSetBalanceConfig } = require('@librechat/api'); +const { createSetBalanceConfig, forceRefreshCloudFrontAuthCookies } = require('@librechat/api'); const { resetPasswordRequestController, resetPasswordController, @@ -28,6 +28,14 @@ const setBalanceConfig = createSetBalanceConfig({ }); const router = express.Router(); +const getCloudFrontAuthCookieRefreshResult = (req, res) => { + const warmedResult = req.cloudFrontAuthCookieRefreshResult; + if (warmedResult && (warmedResult.attempted || !warmedResult.enabled)) { + return warmedResult; + } + + return forceRefreshCloudFrontAuthCookies(req, res, req.user); +}; const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE; //Local @@ -42,6 +50,19 @@ router.post( loginController, ); router.post('/refresh', refreshController); +router.post('/cloudfront/refresh', middleware.requireJwtAuth, (req, res) => { + const result = getCloudFrontAuthCookieRefreshResult(req, res); + if (!result.enabled) { + return res.sendStatus(404); + } + + const status = result.refreshed ? 200 : 500; + return res.status(status).json({ + ok: result.refreshed, + expiresInSec: result.expiresInSec, + refreshAfterSec: result.refreshAfterSec, + }); +}); router.post( '/register', middleware.registerLimiter, diff --git a/api/server/routes/config.js b/api/server/routes/config.js index aaa06a5ee0..46f4cc09da 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,5 +1,5 @@ const express = require('express'); -const { isEnabled, getBalanceConfig } = require('@librechat/api'); +const { isEnabled, getBalanceConfig, getCloudFrontConfig } = require('@librechat/api'); const { defaultSocialLogins } = require('librechat-data-provider'); const { logger, getTenantId, SystemCapabilities } = require('@librechat/data-schemas'); const { hasCapability } = require('~/server/middleware/roles/capabilities'); @@ -116,9 +116,30 @@ function buildWebSearchConfig(appConfig) { }; } +function buildCloudFrontStartupConfig() { + const config = getCloudFrontConfig(); + if ( + config?.imageSigning !== 'cookies' || + !config.domain || + !config.cookieDomain || + !config.privateKey || + !config.keyPairId + ) { + return undefined; + } + + return { + cookieRefresh: { + endpoint: '/api/auth/cloudfront/refresh', + domain: config.domain, + }, + }; +} + router.get('/', async function (req, res) { try { const sharedPayload = buildSharedPayload(); + const cloudFront = buildCloudFrontStartupConfig(); if (!req.user) { const tenantId = getTenantId(); @@ -129,6 +150,7 @@ router.get('/', async function (req, res) { ...sharedPayload, socialLogins: baseConfig?.registration?.socialLogins ?? defaultSocialLogins, turnstile: baseConfig?.turnstileConfig, + ...(cloudFront ? { cloudFront } : {}), }; const interfaceConfig = baseConfig?.interfaceConfig; @@ -170,6 +192,7 @@ router.get('/', async function (req, res) { conversationImportMaxFileSize: process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES ? parseInt(process.env.CONVERSATION_IMPORT_MAX_FILE_SIZE_BYTES, 10) : 0, + ...(cloudFront ? { cloudFront } : {}), }; const webSearch = buildWebSearchConfig(appConfig); diff --git a/client/src/hooks/Config/__tests__/useAppStartup.spec.tsx b/client/src/hooks/Config/__tests__/useAppStartup.spec.tsx index eef2795a76..dd767260d8 100644 --- a/client/src/hooks/Config/__tests__/useAppStartup.spec.tsx +++ b/client/src/hooks/Config/__tests__/useAppStartup.spec.tsx @@ -7,6 +7,21 @@ import type { TUser } from 'librechat-data-provider'; const mockUseHasAccess = jest.fn(); const mockUseMCPServersQuery = jest.fn(); const mockUseMCPToolsQuery = jest.fn(); +const mockInstallCloudFrontImageRetry = jest.fn(() => jest.fn()); +const mockGetTokenHeader = jest.fn(); + +jest.mock('@librechat/client', () => ({ + installCloudFrontImageRetry: (startupConfig: unknown, options: unknown) => + mockInstallCloudFrontImageRetry(startupConfig, options), +})); + +jest.mock('librechat-data-provider', () => { + const actual = jest.requireActual('librechat-data-provider'); + return { + ...actual, + getTokenHeader: () => mockGetTokenHeader(), + }; +}); jest.mock('~/hooks', () => ({ useHasAccess: (args: unknown) => mockUseHasAccess(args), @@ -52,6 +67,7 @@ const wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => ( describe('useAppStartup — MCP permission gating', () => { beforeEach(() => { + mockInstallCloudFrontImageRetry.mockClear(); mockUseMCPServersQuery.mockReturnValue({ data: undefined, isLoading: false }); mockUseMCPToolsQuery.mockReturnValue({ data: undefined, isLoading: false }); }); @@ -120,4 +136,27 @@ describe('useAppStartup — MCP permission gating', () => { expect(mockUseMCPToolsQuery).toHaveBeenCalledWith({ enabled: false }); }); + + it('installs CloudFront image retry from startup config', () => { + mockUseHasAccess.mockReturnValue(false); + const startupConfig = { + cloudFront: { + cookieRefresh: { + endpoint: '/api/auth/cloudfront/refresh', + domain: 'https://cdn.example.com', + }, + }, + } as never; + + renderHook(() => useAppStartup({ startupConfig, user: mockUser }), { wrapper }); + + expect(mockInstallCloudFrontImageRetry).toHaveBeenCalledWith(startupConfig, { + getAuthorizationHeader: expect.any(Function), + }); + const [, options] = mockInstallCloudFrontImageRetry.mock.calls[0]; + mockGetTokenHeader.mockReturnValue('Bearer app-token'); + + expect(options.getAuthorizationHeader()).toBe('Bearer app-token'); + expect(mockGetTokenHeader).toHaveBeenCalledTimes(1); + }); }); diff --git a/client/src/hooks/Config/useAppStartup.ts b/client/src/hooks/Config/useAppStartup.ts index f40b283ee2..fb8b271c86 100644 --- a/client/src/hooks/Config/useAppStartup.ts +++ b/client/src/hooks/Config/useAppStartup.ts @@ -1,7 +1,13 @@ import { useEffect } from 'react'; import { useRecoilState } from 'recoil'; import TagManager from 'react-gtm-module'; -import { LocalStorageKeys, PermissionTypes, Permissions } from 'librechat-data-provider'; +import { installCloudFrontImageRetry } from '@librechat/client'; +import { + getTokenHeader, + LocalStorageKeys, + PermissionTypes, + Permissions, +} from 'librechat-data-provider'; import type { TStartupConfig, TUser } from 'librechat-data-provider'; import { useMCPToolsQuery, useMCPServersQuery } from '~/data-provider'; import { cleanupTimestampedStorage } from '~/utils/timestamps'; @@ -76,6 +82,10 @@ export default function useAppStartup({ }); }, [defaultPreset, setDefaultPreset, startupConfig?.modelSpecs?.list]); + useEffect(() => { + return installCloudFrontImageRetry(startupConfig, { getAuthorizationHeader: getTokenHeader }); + }, [startupConfig]); + useEffect(() => { if (startupConfig?.analyticsGtmId != null && typeof window.google_tag_manager === 'undefined') { const tagManagerArgs = { diff --git a/packages/api/src/cdn/__tests__/cloudfront-cookies.test.ts b/packages/api/src/cdn/__tests__/cloudfront-cookies.test.ts index 44f3e0b7b0..b9d22048e3 100644 --- a/packages/api/src/cdn/__tests__/cloudfront-cookies.test.ts +++ b/packages/api/src/cdn/__tests__/cloudfront-cookies.test.ts @@ -19,6 +19,8 @@ import type { Response } from 'express'; import { setCloudFrontCookies, clearCloudFrontCookies, + forceRefreshCloudFrontAuthCookies, + maybeRefreshCloudFrontAuthCookies, parseCloudFrontCookieScope, } from '../cloudfront-cookies'; @@ -27,6 +29,24 @@ const { logger: mockLogger } = jest.requireMock('@librechat/data-schemas') as { }; const defaultScope = { userId: 'user123' }; +const encodeScope = (scope: object) => + Buffer.from(JSON.stringify(scope), 'utf8').toString('base64url'); + +afterEach(() => { + jest.restoreAllMocks(); +}); + +function defaultCookieConfig(overrides: object = {}) { + return { + domain: 'https://cdn.example.com', + imageSigning: 'cookies', + cookieExpiry: 1800, + cookieDomain: '.example.com', + privateKey: '-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----', + keyPairId: 'K123ABC', + ...overrides, + }; +} describe('setCloudFrontCookies', () => { let mockRes: Partial; @@ -156,6 +176,33 @@ describe('setCloudFrontCookies', () => { expect(cookieNames).toContain('CloudFront-Key-Pair-Id'); }); + it('sets a non-HttpOnly scope cookie with issuedAt and expiresAt timing', () => { + jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000); + mockGetCloudFrontConfig.mockReturnValue(defaultCookieConfig({ cookieExpiry: 1800 })); + mockGetSignedCookies.mockReturnValue({ + 'CloudFront-Policy': 'policy-value', + 'CloudFront-Signature': 'signature-value', + 'CloudFront-Key-Pair-Id': 'K123ABC', + }); + + const result = setCloudFrontCookies(mockRes as Response, { + userId: 'user123', + tenantId: 'tenantA', + storageRegion: 'us-east-2', + }); + + expect(result).toBe(true); + const [, value, options] = cookieArgs.find(([name]) => name === 'LibreChat-CloudFront-Scope')!; + expect(options).toMatchObject({ httpOnly: false, path: '/' }); + expect(parseCloudFrontCookieScope(value)).toEqual({ + userId: 'user123', + tenantId: 'tenantA', + storageRegion: 'us-east-2', + issuedAt: 1_700_000_000, + expiresAt: 1_700_001_800, + }); + }); + it('uses cookieDomain from config with path-scoped cookies', () => { mockGetCloudFrontConfig.mockReturnValue({ domain: 'https://cdn.example.com', @@ -299,8 +346,8 @@ describe('setCloudFrontCookies', () => { const [name, value, options] = cookieArgs[cookieArgs.length - 1]; expect(name).toBe('LibreChat-CloudFront-Scope'); expect(options).toMatchObject({ domain: '.example.com', path: '/' }); - expect(Buffer.from(value, 'base64url').toString('utf8')).toBe( - JSON.stringify({ userId: 'user123', tenantId: 'tenantA', storageRegion: null }), + expect(parseCloudFrontCookieScope(value)).toEqual( + expect.objectContaining({ userId: 'user123', tenantId: 'tenantA' }), ); }); @@ -417,8 +464,12 @@ describe('setCloudFrontCookies', () => { expect(cookieArgs[3][2]).toMatchObject({ path: '/a' }); const [, scopeValue] = cookieArgs[cookieArgs.length - 1]; - expect(Buffer.from(scopeValue, 'base64url').toString('utf8')).toBe( - JSON.stringify({ userId: 'user123', tenantId: 'tenantA', storageRegion: 'us-east-2' }), + expect(parseCloudFrontCookieScope(scopeValue)).toEqual( + expect.objectContaining({ + userId: 'user123', + tenantId: 'tenantA', + storageRegion: 'us-east-2', + }), ); }); @@ -461,8 +512,8 @@ describe('setCloudFrontCookies', () => { Resource: 'https://cdn.example.com/a/r/*/t/tenantA/avatars/*', }), ]); - expect(Buffer.from(scopeValue, 'base64url').toString('utf8')).toBe( - JSON.stringify({ userId: 'user123', tenantId: 'tenantA', storageRegion: null }), + expect(parseCloudFrontCookieScope(scopeValue)).toEqual( + expect.objectContaining({ userId: 'user123', tenantId: 'tenantA' }), ); } finally { if (originalRegion == null) { @@ -640,9 +691,6 @@ describe('setCloudFrontCookies', () => { }); describe('parseCloudFrontCookieScope', () => { - const encodeScope = (scope: object) => - Buffer.from(JSON.stringify(scope), 'utf8').toString('base64url'); - it('round-trips a valid user and tenant scope', () => { const value = encodeScope({ userId: 'user123', tenantId: 'tenantA' }); @@ -667,6 +715,163 @@ describe('parseCloudFrontCookieScope', () => { parseCloudFrontCookieScope(encodeScope({ userId: 'user123', tenantId: 'tenant A' })), ).toBeNull(); }); + + it('handles old scope cookies without timing fields', () => { + expect(parseCloudFrontCookieScope(encodeScope({ userId: 'user123' }))).toEqual({ + userId: 'user123', + }); + }); + + it('drops invalid timing fields while preserving valid scope', () => { + expect( + parseCloudFrontCookieScope( + encodeScope({ userId: 'user123', issuedAt: 'bad', expiresAt: Number.NaN }), + ), + ).toEqual({ userId: 'user123' }); + }); +}); + +describe('maybeRefreshCloudFrontAuthCookies', () => { + let mockRes: Partial; + + beforeEach(() => { + jest.clearAllMocks(); + mockRes = { + cookie: jest.fn().mockReturnThis(), + clearCookie: jest.fn().mockReturnThis(), + }; + mockGetSignedCookies.mockReturnValue({ + 'CloudFront-Policy': 'policy-value', + 'CloudFront-Signature': 'signature-value', + 'CloudFront-Key-Pair-Id': 'K123ABC', + }); + mockGetCloudFrontConfig.mockReturnValue(defaultCookieConfig()); + jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000); + }); + + it('refreshes when the scope cookie is missing', () => { + const result = maybeRefreshCloudFrontAuthCookies({ cookies: {} }, mockRes as Response, { + _id: 'user123', + }); + + expect(result).toMatchObject({ enabled: true, attempted: true, refreshed: true }); + expect(mockGetSignedCookies).toHaveBeenCalled(); + }); + + it('refreshes when the scope cookie is near expiry', () => { + const result = maybeRefreshCloudFrontAuthCookies( + { + cookies: { + 'LibreChat-CloudFront-Scope': encodeScope({ + userId: 'user123', + expiresAt: 1_700_000_250, + }), + }, + }, + mockRes as Response, + { _id: 'user123' }, + ); + + expect(result).toMatchObject({ attempted: true, refreshed: true, reason: 'near_expiry' }); + }); + + it('refreshes when the tenant or user scope mismatches', () => { + const userMismatch = maybeRefreshCloudFrontAuthCookies( + { + cookies: { + 'LibreChat-CloudFront-Scope': encodeScope({ + userId: 'old-user', + tenantId: 'tenantA', + expiresAt: 1_700_001_000, + }), + }, + }, + mockRes as Response, + { _id: 'user123', tenantId: 'tenantA' }, + ); + + const tenantMismatch = maybeRefreshCloudFrontAuthCookies( + { + cookies: { + 'LibreChat-CloudFront-Scope': encodeScope({ + userId: 'user123', + tenantId: 'old-tenant', + expiresAt: 1_700_001_000, + }), + }, + }, + mockRes as Response, + { _id: 'user123', tenantId: 'tenantA' }, + ); + + expect(userMismatch).toMatchObject({ attempted: true, reason: 'user_scope_mismatch' }); + expect(tenantMismatch).toMatchObject({ attempted: true, reason: 'tenant_scope_mismatch' }); + }); + + it('does not refresh when the scope cookie is still fresh', () => { + const result = maybeRefreshCloudFrontAuthCookies( + { + cookies: { + 'LibreChat-CloudFront-Scope': encodeScope({ + userId: 'user123', + expiresAt: 1_700_001_000, + }), + }, + }, + mockRes as Response, + { _id: 'user123' }, + ); + + expect(result).toMatchObject({ + enabled: true, + attempted: false, + refreshed: false, + reason: 'fresh', + }); + expect(mockGetSignedCookies).not.toHaveBeenCalled(); + }); + + it('does not refresh when CloudFront is disabled', () => { + mockGetCloudFrontConfig.mockReturnValue(null); + + const result = maybeRefreshCloudFrontAuthCookies({ cookies: {} }, mockRes as Response, { + _id: 'user123', + }); + + expect(result).toMatchObject({ enabled: false, attempted: false, refreshed: false }); + expect(mockGetSignedCookies).not.toHaveBeenCalled(); + }); + + it('does not refresh when imageSigning is not cookies', () => { + mockGetCloudFrontConfig.mockReturnValue(defaultCookieConfig({ imageSigning: 'url' })); + + const result = maybeRefreshCloudFrontAuthCookies({ cookies: {} }, mockRes as Response, { + _id: 'user123', + }); + + expect(result).toMatchObject({ enabled: false, attempted: false, refreshed: false }); + expect(mockGetSignedCookies).not.toHaveBeenCalled(); + }); + + it('force-refreshes even when the scope cookie is fresh without calling OIDC refresh', () => { + const oidcRefresh = jest.fn(); + + const result = forceRefreshCloudFrontAuthCookies( + { + cookies: { + 'LibreChat-CloudFront-Scope': encodeScope({ + userId: 'user123', + expiresAt: 1_700_001_000, + }), + }, + }, + mockRes as Response, + { _id: 'user123' }, + ); + + expect(result).toMatchObject({ attempted: true, refreshed: true, reason: 'forced' }); + expect(oidcRefresh).not.toHaveBeenCalled(); + }); }); describe('clearCloudFrontCookies', () => { @@ -742,12 +947,17 @@ describe('clearCloudFrontCookies', () => { secure: true, sameSite: 'none', }; + const scopePathOptions = { + ...rootPathOptions, + httpOnly: false, + }; expect(clearedCookies).toContainEqual(['CloudFront-Policy', legacyPathOptions]); expect(clearedCookies).toContainEqual(['CloudFront-Signature', legacyPathOptions]); expect(clearedCookies).toContainEqual(['CloudFront-Key-Pair-Id', legacyPathOptions]); expect(clearedCookies).toContainEqual(['CloudFront-Policy', rootPathOptions]); expect(clearedCookies).toContainEqual(['CloudFront-Signature', rootPathOptions]); expect(clearedCookies).toContainEqual(['CloudFront-Key-Pair-Id', rootPathOptions]); + expect(clearedCookies).toContainEqual(['LibreChat-CloudFront-Scope', scopePathOptions]); expect(clearedCookies).toContainEqual([ 'CloudFront-Policy', expect.objectContaining({ path: '/r' }), @@ -799,7 +1009,7 @@ describe('clearCloudFrontCookies', () => { { domain: '.example.com', path: '/', - httpOnly: true, + httpOnly: false, secure: true, sameSite: 'none', }, diff --git a/packages/api/src/cdn/cloudfront-cookies.ts b/packages/api/src/cdn/cloudfront-cookies.ts index 9bcfb8d58b..98de321043 100644 --- a/packages/api/src/cdn/cloudfront-cookies.ts +++ b/packages/api/src/cdn/cloudfront-cookies.ts @@ -1,7 +1,7 @@ import { getSignedCookies } from '@aws-sdk/cloudfront-signer'; import { logger } from '@librechat/data-schemas'; -import type { Response } from 'express'; +import type { NextFunction, Response } from 'express'; import { INLINE_AVATAR_PATH_PREFIX, INLINE_IMAGE_PATH_PREFIX } from '~/storage/constants'; import { assertPathSegment } from '~/storage/validation'; @@ -23,8 +23,44 @@ export interface CloudFrontCookieScope { userId?: string | null; tenantId?: string | null; storageRegion?: string | null; + issuedAt?: number | null; + expiresAt?: number | null; } +type CloudFrontScopeValue = string | number | { toString(): string } | null | undefined; + +type CloudFrontScopeUser = { + _id?: CloudFrontScopeValue; + id?: CloudFrontScopeValue; + tenantId?: CloudFrontScopeValue; + orgId?: CloudFrontScopeValue; + storageRegion?: CloudFrontScopeValue; +}; + +type CloudFrontCookieRequest = { + cookies?: Partial>; + user?: CloudFrontScopeUser | null; +}; + +type CloudFrontAuthCookieRefreshRequest = CloudFrontCookieRequest & { + cloudFrontAuthCookieRefreshResult?: CloudFrontAuthCookieRefreshResult; +}; + +export type CloudFrontAuthCookieRefreshResult = { + enabled: boolean; + attempted: boolean; + refreshed: boolean; + reason?: string; + expiresInSec?: number; + refreshAfterSec?: number; +}; + +type CloudFrontCookieRefreshOptions = CloudFrontCookieScope & { + orgId?: CloudFrontScopeValue; + force?: boolean; + refreshWindowSec?: number; +}; + type CookieOptions = { domain: string; httpOnly: boolean; @@ -106,6 +142,41 @@ function getPolicyScopes( ]; } +function getConfiguredCookieExpiry(): number { + const config = getCloudFrontConfig(); + return config?.cookieExpiry ?? DEFAULT_COOKIE_EXPIRY; +} + +export function getCloudFrontCookieRefreshWindowSec(cookieExpiry = getConfiguredCookieExpiry()) { + return Math.min(300, Math.floor(cookieExpiry / 4)); +} + +export function getCloudFrontCookieTiming() { + const expiresInSec = getConfiguredCookieExpiry(); + const refreshWindowSec = getCloudFrontCookieRefreshWindowSec(expiresInSec); + return { + expiresInSec, + refreshAfterSec: Math.max(0, expiresInSec - refreshWindowSec), + refreshWindowSec, + }; +} + +function getEffectiveCloudFrontScope( + scope: CloudFrontCookieScope, + includeRegionInPath: boolean, +): CloudFrontCookieScope { + const configuredStorageRegion = + scope.storageRegion ?? + getCloudFrontConfig()?.storageRegion ?? + s3Config.AWS_REGION ?? + process.env.AWS_REGION; + const scopedStorageRegion = includeRegionInPath ? configuredStorageRegion : scope.storageRegion; + return { + ...scope, + ...(scopedStorageRegion ? { storageRegion: scopedStorageRegion } : {}), + }; +} + function getScopeCookiePaths( scope: CloudFrontCookieScope, { includeTenantRoot = false }: { includeTenantRoot?: boolean } = {}, @@ -136,6 +207,8 @@ function encodeCloudFrontCookieScope(scope: CloudFrontCookieScope): string { userId: scope.userId ?? null, tenantId: scope.tenantId ?? null, storageRegion: scope.storageRegion ?? null, + issuedAt: scope.issuedAt ?? null, + expiresAt: scope.expiresAt ?? null, }; return Buffer.from(JSON.stringify(payload), 'utf8').toString('base64url'); } @@ -152,6 +225,8 @@ export function parseCloudFrontCookieScope( userId?: unknown; tenantId?: unknown; storageRegion?: unknown; + issuedAt?: unknown; + expiresAt?: unknown; }; const scope: CloudFrontCookieScope = {}; if (typeof parsed.userId === 'string') { @@ -163,12 +238,111 @@ export function parseCloudFrontCookieScope( if (typeof parsed.storageRegion === 'string') { scope.storageRegion = assertPolicyPathSegment('storageRegion', parsed.storageRegion); } + if (typeof parsed.issuedAt === 'number' && Number.isFinite(parsed.issuedAt)) { + scope.issuedAt = parsed.issuedAt; + } + if (typeof parsed.expiresAt === 'number' && Number.isFinite(parsed.expiresAt)) { + scope.expiresAt = parsed.expiresAt; + } return scope.userId ? scope : null; } catch { return null; } } +function normalizeCloudFrontScopeValue(value: CloudFrontScopeValue): string | undefined { + if (value == null) { + return undefined; + } + + const normalized = String(value); + return normalized.length > 0 ? normalized : undefined; +} + +function getCloudFrontScopeValue( + optionsValue: CloudFrontScopeValue, + userValue: CloudFrontScopeValue, + requestValue: CloudFrontScopeValue, +): string | undefined { + return normalizeCloudFrontScopeValue(optionsValue ?? userValue ?? requestValue); +} + +export function resolveCloudFrontCookieScope( + req: CloudFrontCookieRequest | null | undefined, + user: CloudFrontScopeUser | null | undefined, + options: CloudFrontCookieRefreshOptions = {}, +): CloudFrontCookieScope { + const storageRegion = getCloudFrontScopeValue( + options.storageRegion, + user?.storageRegion, + req?.user?.storageRegion, + ); + return { + userId: getCloudFrontScopeValue( + options.userId, + user?._id ?? user?.id, + req?.user?._id ?? req?.user?.id, + ), + tenantId: getCloudFrontScopeValue( + options.tenantId ?? options.orgId, + user?.tenantId ?? user?.orgId, + req?.user?.tenantId ?? req?.user?.orgId, + ), + ...(storageRegion ? { storageRegion } : {}), + }; +} + +function getPreviousCloudFrontScope( + req: CloudFrontCookieRequest | null | undefined, +): CloudFrontCookieScope | null { + return parseCloudFrontCookieScope(req?.cookies?.[CLOUDFRONT_SCOPE_COOKIE]); +} + +function getCloudFrontCookieSkipReason(scope: CloudFrontCookieScope): string | null { + const config = getCloudFrontConfig(); + if (!config || config.imageSigning !== 'cookies' || !config.privateKey || !config.keyPairId) { + return 'cloudfront_disabled'; + } + if (!config.cookieDomain) { + return 'missing_cookie_domain'; + } + if (!scope.userId) { + return 'missing_user_id'; + } + return null; +} + +function getScopeRefreshReason( + previousScope: CloudFrontCookieScope | null, + currentScope: CloudFrontCookieScope, + refreshWindowSec: number, +): string | null { + if (!previousScope?.userId) { + return 'missing_scope'; + } + if (previousScope.userId !== currentScope.userId) { + return 'user_scope_mismatch'; + } + if ((previousScope.tenantId ?? null) !== (currentScope.tenantId ?? null)) { + return 'tenant_scope_mismatch'; + } + if ((previousScope.storageRegion ?? null) !== (currentScope.storageRegion ?? null)) { + return 'storage_region_scope_mismatch'; + } + + const expiresAt = Number(previousScope.expiresAt); + if (!Number.isFinite(expiresAt)) { + return 'missing_expiry'; + } + + const now = Math.floor(Date.now() / 1000); + if (expiresAt - now <= refreshWindowSec) { + return 'near_expiry'; + } + + return null; +} + function clearCookiePaths( res: Response, baseOptions: CookieOptions, @@ -215,7 +389,7 @@ export function clearCloudFrontCookies(res: Response, scope: CloudFrontCookieSco } clearCookiePaths(res, baseOptions, paths); - res.clearCookie(CLOUDFRONT_SCOPE_COOKIE, { ...baseOptions, path: '/' }); + res.clearCookie(CLOUDFRONT_SCOPE_COOKIE, { ...baseOptions, httpOnly: false, path: '/' }); } catch (error) { logger.warn('[clearCloudFrontCookies] Failed to clear cookies:', error); } @@ -254,20 +428,15 @@ export function setCloudFrontCookies( try { const { keyPairId, privateKey } = config; - const cookieExpiry = config.cookieExpiry ?? DEFAULT_COOKIE_EXPIRY; - const expiresAtMs = Date.now() + cookieExpiry * 1000; + const cookieExpiry = getConfiguredCookieExpiry(); + const issuedAtEpoch = Math.floor(Date.now() / 1000); + const expiresAtEpoch = issuedAtEpoch + cookieExpiry; + const expiresAtMs = expiresAtEpoch * 1000; const expiresAt = new Date(expiresAtMs); - const expiresAtEpoch = Math.floor(expiresAtMs / 1000); const cleanDomain = config.domain.replace(/\/+$/, ''); const includeRegionInPath = config.includeRegionInPath ?? false; - const configuredStorageRegion = - scope.storageRegion ?? config.storageRegion ?? s3Config.AWS_REGION ?? process.env.AWS_REGION; - const scopedStorageRegion = includeRegionInPath ? configuredStorageRegion : scope.storageRegion; - const effectiveScope = { - ...scope, - ...(scopedStorageRegion ? { storageRegion: scopedStorageRegion } : {}), - }; + const effectiveScope = getEffectiveCloudFrontScope(scope, includeRegionInPath); const policyScopes = getPolicyScopes(cleanDomain, effectiveScope, includeRegionInPath); const resourcesByPath = new Map(); for (const { resource, path } of policyScopes) { @@ -336,8 +505,14 @@ export function setCloudFrontCookies( res.cookie(key, cookies[key], cookieOptions); } } - res.cookie(CLOUDFRONT_SCOPE_COOKIE, encodeCloudFrontCookieScope(effectiveScope), { + const scopeCookieValue = encodeCloudFrontCookieScope({ + ...effectiveScope, + issuedAt: issuedAtEpoch, + expiresAt: expiresAtEpoch, + }); + res.cookie(CLOUDFRONT_SCOPE_COOKIE, scopeCookieValue, { ...baseCookieOptions, + httpOnly: false, path: '/', }); @@ -351,3 +526,120 @@ export function setCloudFrontCookies( return false; } } + +export function maybeRefreshCloudFrontAuthCookies( + req: CloudFrontCookieRequest | null | undefined, + res: Response, + user: CloudFrontScopeUser | null | undefined, + options: CloudFrontCookieRefreshOptions = {}, +): CloudFrontAuthCookieRefreshResult { + try { + const config = getCloudFrontConfig(); + const scope = resolveCloudFrontCookieScope(req, user, options); + const skipReason = getCloudFrontCookieSkipReason(scope); + const timing = getCloudFrontCookieTiming(); + + if (skipReason) { + logger.debug('[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookies skipped', { + attempted: false, + refreshed: false, + reason: skipReason, + has_user_id: Boolean(scope.userId), + has_tenant_scope: Boolean(scope.tenantId), + has_storage_region: Boolean(scope.storageRegion), + }); + return { + enabled: false, + attempted: false, + refreshed: false, + reason: skipReason, + }; + } + + const includeRegionInPath = config?.includeRegionInPath ?? false; + const effectiveScope = getEffectiveCloudFrontScope(scope, includeRegionInPath); + const previousScope = getPreviousCloudFrontScope(req); + const refreshWindowSec = options.refreshWindowSec ?? timing.refreshWindowSec; + const refreshReason = options.force + ? 'forced' + : getScopeRefreshReason(previousScope, effectiveScope, refreshWindowSec); + + if (!refreshReason) { + logger.debug('[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookies still fresh', { + attempted: false, + refreshed: false, + reason: 'fresh', + refresh_window_sec: refreshWindowSec, + }); + return { + enabled: true, + attempted: false, + refreshed: false, + reason: 'fresh', + expiresInSec: timing.expiresInSec, + refreshAfterSec: timing.refreshAfterSec, + }; + } + + const cookiesSet = setCloudFrontCookies(res, effectiveScope, previousScope); + const logPayload = { + attempted: true, + refreshed: cookiesSet, + reason: cookiesSet ? refreshReason : 'set_failed', + refresh_window_sec: refreshWindowSec, + has_tenant_scope: Boolean(effectiveScope.tenantId), + has_storage_region: Boolean(effectiveScope.storageRegion), + has_previous_scope: Boolean(previousScope?.userId), + }; + + if (cookiesSet) { + logger.debug( + '[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookies refreshed', + logPayload, + ); + } else { + logger.warn( + '[maybeRefreshCloudFrontAuthCookies] CloudFront auth cookie refresh failed', + logPayload, + ); + } + + return { + enabled: true, + attempted: true, + refreshed: cookiesSet, + reason: cookiesSet ? refreshReason : 'set_failed', + expiresInSec: timing.expiresInSec, + refreshAfterSec: timing.refreshAfterSec, + }; + } catch (error) { + logger.warn( + '[maybeRefreshCloudFrontAuthCookies] Failed to refresh CloudFront auth cookies:', + error, + ); + return { + enabled: false, + attempted: false, + refreshed: false, + reason: 'error', + }; + } +} + +export function forceRefreshCloudFrontAuthCookies( + req: CloudFrontCookieRequest | null | undefined, + res: Response, + user: CloudFrontScopeUser | null | undefined, + options: CloudFrontCookieRefreshOptions = {}, +): CloudFrontAuthCookieRefreshResult { + return maybeRefreshCloudFrontAuthCookies(req, res, user, { ...options, force: true }); +} + +export function maybeRefreshCloudFrontAuthCookiesMiddleware( + req: CloudFrontAuthCookieRefreshRequest, + res: Response, + next: NextFunction, +): void { + req.cloudFrontAuthCookieRefreshResult = maybeRefreshCloudFrontAuthCookies(req, res, req.user); + next(); +} diff --git a/packages/client/src/utils/cloudfront.spec.tsx b/packages/client/src/utils/cloudfront.spec.tsx new file mode 100644 index 0000000000..6b4e0affc9 --- /dev/null +++ b/packages/client/src/utils/cloudfront.spec.tsx @@ -0,0 +1,202 @@ +import { fireEvent, waitFor } from '@testing-library/react'; + +const mockApiBaseUrl = jest.fn(() => ''); +const mockGetTokenHeader = jest.fn(() => 'Bearer test-token'); + +jest.mock('librechat-data-provider', () => ({ + apiBaseUrl: () => mockApiBaseUrl(), +})); + +import { + isCloudFrontMediaUrl, + refreshCloudFrontCookiesOnce, + installCloudFrontImageRetry, + configureCloudFrontCookieRefresh, +} from './cloudfront'; + +const cloudFrontStartupConfig = { + cloudFront: { + cookieRefresh: { + endpoint: '/api/auth/cloudfront/refresh', + domain: 'https://cdn.example.com', + }, + }, +}; + +function refreshResponse(payload: { ok?: boolean }, ok = true): Response { + return { + ok, + json: () => Promise.resolve(payload), + } as Response; +} + +describe('CloudFront cookie refresh helpers', () => { + let fetchMock: jest.MockedFunction; + const originalFetch = global.fetch; + + beforeEach(() => { + mockApiBaseUrl.mockReturnValue(''); + mockGetTokenHeader.mockReturnValue('Bearer test-token'); + fetchMock = jest.fn(() => + Promise.resolve(refreshResponse({ ok: true })), + ) as jest.MockedFunction; + global.fetch = fetchMock; + configureCloudFrontCookieRefresh(undefined); + jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000); + }); + + afterEach(() => { + global.fetch = originalFetch; + jest.restoreAllMocks(); + }); + + it('no-ops when startup config has no CloudFront refresh capability', async () => { + configureCloudFrontCookieRefresh({}); + + await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(false); + + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it('dedupes concurrent refresh calls', async () => { + let resolveRefresh: ((value: Response) => void) | undefined; + fetchMock.mockReturnValue( + new Promise((resolve) => { + resolveRefresh = resolve; + }), + ); + configureCloudFrontCookieRefresh(cloudFrontStartupConfig, { + getAuthorizationHeader: mockGetTokenHeader, + }); + + const first = refreshCloudFrontCookiesOnce(); + const second = refreshCloudFrontCookiesOnce(); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock).toHaveBeenCalledWith( + '/api/auth/cloudfront/refresh', + expect.objectContaining({ + method: 'POST', + credentials: 'include', + headers: expect.objectContaining({ Authorization: 'Bearer test-token' }), + body: '{}', + }), + ); + resolveRefresh?.(refreshResponse({ ok: true })); + await expect(first).resolves.toBe(true); + await expect(second).resolves.toBe(true); + }); + + it('returns false on 401 without retrying the refresh request', async () => { + fetchMock.mockResolvedValue(refreshResponse({}, false)); + configureCloudFrontCookieRefresh(cloudFrontStartupConfig, { + getAuthorizationHeader: mockGetTokenHeader, + }); + + await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(false); + + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it('prefixes the refresh endpoint with the configured app base path', async () => { + mockApiBaseUrl.mockReturnValue('/chat'); + configureCloudFrontCookieRefresh(cloudFrontStartupConfig, { + getAuthorizationHeader: mockGetTokenHeader, + }); + + await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(true); + + expect(fetchMock).toHaveBeenCalledWith( + '/chat/api/auth/cloudfront/refresh', + expect.objectContaining({ method: 'POST' }), + ); + }); + + it('detects only the configured CloudFront domain', () => { + expect( + isCloudFrontMediaUrl( + 'https://cdn.example.com/i/images/user/file.png', + cloudFrontStartupConfig, + ), + ).toBe(true); + expect( + isCloudFrontMediaUrl( + 'https://images.example.net/i/images/user/file.png', + cloudFrontStartupConfig, + ), + ).toBe(false); + }); + + it('retries a configured CloudFront image only once from the global listener', async () => { + const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig); + const img = document.createElement('img'); + const onFailure = jest.fn(); + img.src = 'https://cdn.example.com/i/images/user/file.png'; + img.addEventListener('error', onFailure); + document.body.appendChild(img); + + fireEvent.error(img); + + await waitFor(() => + expect(img).toHaveAttribute( + 'src', + 'https://cdn.example.com/i/images/user/file.png?_cf_refresh=1700000000000', + ), + ); + + fireEvent.error(img); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(onFailure).toHaveBeenCalledTimes(1); + + cleanup(); + img.remove(); + }); + + it('does not consume the one retry when cookie refresh fails', async () => { + fetchMock + .mockResolvedValueOnce(refreshResponse({ ok: false })) + .mockResolvedValueOnce(refreshResponse({ ok: true })); + const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig); + const img = document.createElement('img'); + const onFailure = jest.fn(); + img.src = 'https://cdn.example.com/i/images/user/file.png'; + img.addEventListener('error', onFailure); + document.body.appendChild(img); + + fireEvent.error(img); + + await waitFor(() => expect(onFailure).toHaveBeenCalledTimes(1)); + expect(img).toHaveAttribute('src', 'https://cdn.example.com/i/images/user/file.png'); + + fireEvent.error(img); + + await waitFor(() => + expect(img).toHaveAttribute( + 'src', + 'https://cdn.example.com/i/images/user/file.png?_cf_refresh=1700000000000', + ), + ); + expect(fetchMock).toHaveBeenCalledTimes(2); + + cleanup(); + img.remove(); + }); + + it('does not retry arbitrary external images', () => { + const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig); + const img = document.createElement('img'); + const onFailure = jest.fn(); + img.src = 'https://example.com/photo.png'; + img.addEventListener('error', onFailure); + document.body.appendChild(img); + + fireEvent.error(img); + + expect(fetchMock).not.toHaveBeenCalled(); + expect(onFailure).toHaveBeenCalledTimes(1); + + cleanup(); + img.remove(); + }); +}); diff --git a/packages/client/src/utils/cloudfront.ts b/packages/client/src/utils/cloudfront.ts new file mode 100644 index 0000000000..1dc07cff03 --- /dev/null +++ b/packages/client/src/utils/cloudfront.ts @@ -0,0 +1,209 @@ +import { apiBaseUrl } from 'librechat-data-provider'; +import type { TStartupConfig } from 'librechat-data-provider'; + +type CloudFrontCookieRefreshConfig = NonNullable< + NonNullable['cookieRefresh'] +>; +type CloudFrontCookieRefreshResponse = { + ok?: boolean; +}; +type CloudFrontCookieRefreshOptions = { + getAuthorizationHeader?: () => string | undefined; +}; + +let cookieRefreshConfig: CloudFrontCookieRefreshConfig | undefined; +let getAuthorizationHeader: CloudFrontCookieRefreshOptions['getAuthorizationHeader']; +let refreshPromise: Promise | null = null; +let removeImageErrorListener: (() => void) | null = null; +const retriedImageSources = new WeakMap(); +const pendingImageRefreshes = new WeakMap(); +const forwardedImageErrors = new WeakSet(); + +function getRefreshConfig( + startupConfig?: Pick | null, +): CloudFrontCookieRefreshConfig | undefined { + return startupConfig?.cloudFront?.cookieRefresh ?? cookieRefreshConfig; +} + +function getBaseUrl(): string { + return typeof window === 'undefined' ? 'http://localhost' : window.location.origin; +} + +function parseUrl(value: string): URL | null { + try { + return new URL(value, getBaseUrl()); + } catch { + return null; + } +} + +export function configureCloudFrontCookieRefresh( + startupConfig?: Pick | null, + options: CloudFrontCookieRefreshOptions = {}, +): void { + cookieRefreshConfig = startupConfig?.cloudFront?.cookieRefresh; + getAuthorizationHeader = options.getAuthorizationHeader; +} + +export function isCloudFrontMediaUrl( + url: string | null | undefined, + startupConfig?: Pick | null, +): boolean { + const config = getRefreshConfig(startupConfig); + if (!url || !config?.domain) { + return false; + } + + const mediaUrl = parseUrl(url); + const cloudFrontUrl = parseUrl(config.domain); + return mediaUrl?.origin === cloudFrontUrl?.origin; +} + +export function withCloudFrontCacheBuster(url: string): string { + const parsed = parseUrl(url); + if (!parsed) { + return url; + } + + parsed.searchParams.set('_cf_refresh', Date.now().toString()); + return parsed.toString(); +} + +function getRetryKey(url: string): string { + const parsed = parseUrl(url); + if (!parsed) { + return url; + } + + parsed.searchParams.delete('_cf_refresh'); + return parsed.toString(); +} + +function dispatchImageError(img: HTMLImageElement): void { + forwardedImageErrors.add(img); + img.dispatchEvent(new Event('error')); +} + +function getRefreshEndpoint(endpoint: string): string { + if (/^https?:\/\//i.test(endpoint)) { + return endpoint; + } + + const baseUrl = apiBaseUrl(); + if (!baseUrl || endpoint === baseUrl || endpoint.startsWith(`${baseUrl}/`)) { + return endpoint; + } + + return `${baseUrl}${endpoint.startsWith('/') ? '' : '/'}${endpoint}`; +} + +async function postCloudFrontCookieRefresh(endpoint: string): Promise { + const authorization = getAuthorizationHeader?.(); + const headers: Record = { + Accept: 'application/json', + 'Content-Type': 'application/json', + }; + + if (authorization) { + headers.Authorization = authorization; + } + + const response = await fetch(endpoint, { + method: 'POST', + credentials: 'include', + headers, + body: '{}', + }); + + if (!response.ok) { + return false; + } + + const payload = (await response.json()) as CloudFrontCookieRefreshResponse; + return payload.ok === true; +} + +export function refreshCloudFrontCookiesOnce(): Promise { + const config = getRefreshConfig(); + if (!config?.endpoint) { + return Promise.resolve(false); + } + + if (refreshPromise) { + return refreshPromise; + } + + const endpoint = getRefreshEndpoint(config.endpoint); + refreshPromise = postCloudFrontCookieRefresh(endpoint) + .catch(() => false) + .finally(() => { + refreshPromise = null; + }); + + return refreshPromise; +} + +export function installCloudFrontImageRetry( + startupConfig?: Pick | null, + options: CloudFrontCookieRefreshOptions = {}, +): () => void { + configureCloudFrontCookieRefresh(startupConfig, options); + removeImageErrorListener?.(); + removeImageErrorListener = null; + + const config = getRefreshConfig(); + if (typeof window === 'undefined' || !config?.endpoint || !config.domain) { + return () => undefined; + } + + const handleImageError = (event: Event) => { + const img = event.target; + if (!(img instanceof HTMLImageElement)) { + return; + } + if (forwardedImageErrors.has(img)) { + forwardedImageErrors.delete(img); + return; + } + + const failedSrc = img.currentSrc || img.src || img.getAttribute('src') || ''; + if (!isCloudFrontMediaUrl(failedSrc)) { + return; + } + + const retryKey = getRetryKey(failedSrc); + if (retriedImageSources.get(img) === retryKey) { + return; + } + + event.preventDefault(); + event.stopPropagation(); + event.stopImmediatePropagation(); + if (pendingImageRefreshes.get(img) === retryKey) { + return; + } + pendingImageRefreshes.set(img, retryKey); + + void refreshCloudFrontCookiesOnce().then((refreshed) => { + pendingImageRefreshes.delete(img); + if (!refreshed || !img.isConnected) { + dispatchImageError(img); + return; + } + + retriedImageSources.set(img, retryKey); + img.src = withCloudFrontCacheBuster(failedSrc); + }); + }; + + window.addEventListener('error', handleImageError, true); + const cleanup = () => { + window.removeEventListener('error', handleImageError, true); + if (removeImageErrorListener === cleanup) { + removeImageErrorListener = null; + } + }; + removeImageErrorListener = cleanup; + + return cleanup; +} diff --git a/packages/client/src/utils/index.ts b/packages/client/src/utils/index.ts index dfa740defd..fef7b9ed36 100644 --- a/packages/client/src/utils/index.ts +++ b/packages/client/src/utils/index.ts @@ -1,3 +1,4 @@ export * from './utils'; export * from './theme'; +export * from './cloudfront'; export { default as logger } from './logger'; diff --git a/packages/data-provider/specs/headers-helpers.spec.ts b/packages/data-provider/specs/headers-helpers.spec.ts index 4df7a2f934..60a0d291e0 100644 --- a/packages/data-provider/specs/headers-helpers.spec.ts +++ b/packages/data-provider/specs/headers-helpers.spec.ts @@ -1,5 +1,5 @@ import axios from 'axios'; -import { setTokenHeader } from '../src/headers-helpers'; +import { getTokenHeader, setTokenHeader } from '../src/headers-helpers'; describe('setTokenHeader', () => { afterEach(() => { @@ -9,12 +9,14 @@ describe('setTokenHeader', () => { it('sets the Authorization header with a Bearer token', () => { setTokenHeader('my-token'); expect(axios.defaults.headers.common['Authorization']).toBe('Bearer my-token'); + expect(getTokenHeader()).toBe('Bearer my-token'); }); it('deletes the Authorization header when called with undefined', () => { axios.defaults.headers.common['Authorization'] = 'Bearer old-token'; setTokenHeader(undefined); expect(axios.defaults.headers.common['Authorization']).toBeUndefined(); + expect(getTokenHeader()).toBeUndefined(); }); it('is a no-op when clearing an already absent header', () => { diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 5c4e1c6ce0..b003f6137c 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -1103,6 +1103,12 @@ export type TStartupConfig = { scraperProvider?: ScraperProviders; rerankerType?: RerankerTypes; }; + cloudFront?: { + cookieRefresh?: { + endpoint: string; + domain: string; + }; + }; mcpServers?: Record< string, { diff --git a/packages/data-provider/src/headers-helpers.ts b/packages/data-provider/src/headers-helpers.ts index fa24b36997..b591ff4bc6 100644 --- a/packages/data-provider/src/headers-helpers.ts +++ b/packages/data-provider/src/headers-helpers.ts @@ -11,3 +11,8 @@ export function setTokenHeader(token: string | undefined) { axios.defaults.headers.common['Authorization'] = 'Bearer ' + token; } } + +export function getTokenHeader(): string | undefined { + const authorization = axios.defaults.headers.common['Authorization']; + return typeof authorization === 'string' ? authorization : undefined; +}