fix: bypass auth refresh for CloudFront cookie retry

This commit is contained in:
Danny Avila 2026-05-12 11:07:46 -04:00
parent b26ddfb6bf
commit ceb40c55f2
4 changed files with 87 additions and 22 deletions

View file

@ -1,13 +1,11 @@
import { fireEvent, waitFor } from '@testing-library/react';
const mockRequestPost = jest.fn();
const mockApiBaseUrl = jest.fn(() => '');
const mockGetTokenHeader = jest.fn(() => 'Bearer test-token');
jest.mock('librechat-data-provider', () => ({
apiBaseUrl: () => mockApiBaseUrl(),
request: {
post: (...args: unknown[]) => mockRequestPost(...args),
},
getTokenHeader: () => mockGetTokenHeader(),
}));
import {
@ -26,15 +24,30 @@ const cloudFrontStartupConfig = {
},
};
function refreshResponse(payload: { ok?: boolean }, ok = true): Response {
return {
ok,
json: () => Promise.resolve(payload),
} as Response;
}
describe('CloudFront cookie refresh helpers', () => {
let fetchMock: jest.MockedFunction<typeof fetch>;
const originalFetch = global.fetch;
beforeEach(() => {
mockRequestPost.mockReset();
mockApiBaseUrl.mockReturnValue('');
mockGetTokenHeader.mockReturnValue('Bearer test-token');
fetchMock = jest.fn(() =>
Promise.resolve(refreshResponse({ ok: true })),
) as jest.MockedFunction<typeof fetch>;
global.fetch = fetchMock;
configureCloudFrontCookieRefresh(undefined);
jest.spyOn(Date, 'now').mockReturnValue(1_700_000_000_000);
});
afterEach(() => {
global.fetch = originalFetch;
jest.restoreAllMocks();
});
@ -43,12 +56,12 @@ describe('CloudFront cookie refresh helpers', () => {
await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(false);
expect(mockRequestPost).not.toHaveBeenCalled();
expect(fetchMock).not.toHaveBeenCalled();
});
it('dedupes concurrent refresh calls', async () => {
let resolveRefresh: ((value: { ok: boolean }) => void) | undefined;
mockRequestPost.mockReturnValue(
let resolveRefresh: ((value: Response) => void) | undefined;
fetchMock.mockReturnValue(
new Promise((resolve) => {
resolveRefresh = resolve;
}),
@ -58,21 +71,40 @@ describe('CloudFront cookie refresh helpers', () => {
const first = refreshCloudFrontCookiesOnce();
const second = refreshCloudFrontCookiesOnce();
expect(mockRequestPost).toHaveBeenCalledTimes(1);
expect(mockRequestPost).toHaveBeenCalledWith('/api/auth/cloudfront/refresh', {});
resolveRefresh?.({ ok: true });
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenCalledWith(
'/api/auth/cloudfront/refresh',
expect.objectContaining({
method: 'POST',
credentials: 'include',
headers: expect.objectContaining({ Authorization: 'Bearer test-token' }),
body: '{}',
}),
);
resolveRefresh?.(refreshResponse({ ok: true }));
await expect(first).resolves.toBe(true);
await expect(second).resolves.toBe(true);
});
it('returns false on 401 without retrying the refresh request', async () => {
fetchMock.mockResolvedValue(refreshResponse({}, false));
configureCloudFrontCookieRefresh(cloudFrontStartupConfig);
await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(false);
expect(fetchMock).toHaveBeenCalledTimes(1);
});
it('prefixes the refresh endpoint with the configured app base path', async () => {
mockApiBaseUrl.mockReturnValue('/chat');
mockRequestPost.mockResolvedValue({ ok: true });
configureCloudFrontCookieRefresh(cloudFrontStartupConfig);
await expect(refreshCloudFrontCookiesOnce()).resolves.toBe(true);
expect(mockRequestPost).toHaveBeenCalledWith('/chat/api/auth/cloudfront/refresh', {});
expect(fetchMock).toHaveBeenCalledWith(
'/chat/api/auth/cloudfront/refresh',
expect.objectContaining({ method: 'POST' }),
);
});
it('detects only the configured CloudFront domain', () => {
@ -91,7 +123,6 @@ describe('CloudFront cookie refresh helpers', () => {
});
it('retries a configured CloudFront image only once from the global listener', async () => {
mockRequestPost.mockResolvedValue({ ok: true });
const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig);
const img = document.createElement('img');
const onFailure = jest.fn();
@ -110,7 +141,7 @@ describe('CloudFront cookie refresh helpers', () => {
fireEvent.error(img);
expect(mockRequestPost).toHaveBeenCalledTimes(1);
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(onFailure).toHaveBeenCalledTimes(1);
cleanup();
@ -118,7 +149,6 @@ describe('CloudFront cookie refresh helpers', () => {
});
it('does not retry arbitrary external images', () => {
mockRequestPost.mockResolvedValue({ ok: true });
const cleanup = installCloudFrontImageRetry(cloudFrontStartupConfig);
const img = document.createElement('img');
const onFailure = jest.fn();
@ -128,7 +158,7 @@ describe('CloudFront cookie refresh helpers', () => {
fireEvent.error(img);
expect(mockRequestPost).not.toHaveBeenCalled();
expect(fetchMock).not.toHaveBeenCalled();
expect(onFailure).toHaveBeenCalledTimes(1);
cleanup();

View file

@ -1,8 +1,12 @@
import { apiBaseUrl, request, type TStartupConfig } from 'librechat-data-provider';
import { apiBaseUrl, getTokenHeader } from 'librechat-data-provider';
import type { TStartupConfig } from 'librechat-data-provider';
type CloudFrontCookieRefreshConfig = NonNullable<
NonNullable<TStartupConfig['cloudFront']>['cookieRefresh']
>;
type CloudFrontCookieRefreshResponse = {
ok?: boolean;
};
let cookieRefreshConfig: CloudFrontCookieRefreshConfig | undefined;
let refreshPromise: Promise<boolean> | null = null;
@ -84,6 +88,32 @@ function getRefreshEndpoint(endpoint: string): string {
return `${baseUrl}${endpoint.startsWith('/') ? '' : '/'}${endpoint}`;
}
async function postCloudFrontCookieRefresh(endpoint: string): Promise<boolean> {
const authorization = getTokenHeader();
const headers: Record<string, string> = {
Accept: 'application/json',
'Content-Type': 'application/json',
};
if (authorization) {
headers.Authorization = authorization;
}
const response = await fetch(endpoint, {
method: 'POST',
credentials: 'include',
headers,
body: '{}',
});
if (!response.ok) {
return false;
}
const payload = (await response.json()) as CloudFrontCookieRefreshResponse;
return payload.ok === true;
}
export function refreshCloudFrontCookiesOnce(): Promise<boolean> {
const config = getRefreshConfig();
if (!config?.endpoint) {
@ -95,9 +125,7 @@ export function refreshCloudFrontCookiesOnce(): Promise<boolean> {
}
const endpoint = getRefreshEndpoint(config.endpoint);
refreshPromise = request
.post(endpoint, {})
.then((payload: { ok?: boolean }) => payload.ok === true)
refreshPromise = postCloudFrontCookieRefresh(endpoint)
.catch(() => false)
.finally(() => {
refreshPromise = null;

View file

@ -1,5 +1,5 @@
import axios from 'axios';
import { setTokenHeader } from '../src/headers-helpers';
import { getTokenHeader, setTokenHeader } from '../src/headers-helpers';
describe('setTokenHeader', () => {
afterEach(() => {
@ -9,12 +9,14 @@ describe('setTokenHeader', () => {
it('sets the Authorization header with a Bearer token', () => {
setTokenHeader('my-token');
expect(axios.defaults.headers.common['Authorization']).toBe('Bearer my-token');
expect(getTokenHeader()).toBe('Bearer my-token');
});
it('deletes the Authorization header when called with undefined', () => {
axios.defaults.headers.common['Authorization'] = 'Bearer old-token';
setTokenHeader(undefined);
expect(axios.defaults.headers.common['Authorization']).toBeUndefined();
expect(getTokenHeader()).toBeUndefined();
});
it('is a no-op when clearing an already absent header', () => {

View file

@ -11,3 +11,8 @@ export function setTokenHeader(token: string | undefined) {
axios.defaults.headers.common['Authorization'] = 'Bearer ' + token;
}
}
export function getTokenHeader(): string | undefined {
const authorization = axios.defaults.headers.common['Authorization'];
return typeof authorization === 'string' ? authorization : undefined;
}