fix: optimize tool token cache alignment

This commit is contained in:
Danny Avila 2026-05-09 12:37:14 -04:00
parent ae8e6b33e5
commit a2105df07e
3 changed files with 129 additions and 36 deletions

View file

@@ -10,16 +10,30 @@ import type { GenericTool, LCTool, LCToolRegistry, TokenCounter } from '@librech
import { collectToolSchemas, computeToolSchemaTokens, getOrComputeToolTokens } from './toolTokens';
/* ---------- Mock standardCache with hoisted get/set for per-test overrides ---------- */
/* ---------- Mock standardCache with hoisted methods for per-test overrides ---------- */
const mockCacheStore = new Map<string, unknown>();
const mockGet = jest.fn((key: string) => Promise.resolve(mockCacheStore.get(key)));
const mockGetMany = jest.fn((keys: string[]) =>
Promise.resolve(keys.map((key) => mockCacheStore.get(key))),
);
const mockSet = jest.fn((key: string, value: unknown) => {
mockCacheStore.set(key, value);
return Promise.resolve(true);
});
const mockSetMany = jest.fn((entries: Array<{ key: string; value: unknown }>) => {
for (const { key, value } of entries) {
mockCacheStore.set(key, value);
}
return Promise.resolve(entries.map(() => true));
});
jest.mock('~/cache', () => ({
standardCache: jest.fn(() => ({ get: mockGet, set: mockSet })),
standardCache: jest.fn(() => ({
get: mockGet,
getMany: mockGetMany,
set: mockSet,
setMany: mockSetMany,
})),
}));
jest.mock('@librechat/data-schemas', () => ({
@@ -68,11 +82,24 @@ function hasCacheKey(fragment: string): boolean {
beforeEach(() => {
mockCacheStore.clear();
mockGet.mockClear();
mockGetMany.mockClear();
mockSet.mockClear();
mockSetMany.mockClear();
mockGet.mockImplementation((key: string) => Promise.resolve(mockCacheStore.get(key)));
mockGetMany.mockImplementation((keys: string[]) =>
Promise.resolve(keys.map((key) => mockCacheStore.get(key))),
);
mockSet.mockImplementation((key: string, value: unknown) => {
mockCacheStore.set(key, value);
return Promise.resolve(true);
});
mockSetMany.mockImplementation((entries: Array<{ key: string; value: unknown }>) => {
for (const { key, value } of entries) {
mockCacheStore.set(key, value);
}
return Promise.resolve(entries.map(() => true));
});
});
/* ========================================================================= */
@@ -331,23 +358,38 @@ describe('getOrComputeToolTokens', () => {
expect(counter.mock.calls.length).toBe(callCountAfterFirst);
});
it('applies different multipliers while separating provider namespaces', async () => {
it('applies different multipliers while reusing raw tokenizer counts', async () => {
const defs = [makeToolDef('tool')];
const counter = jest.fn(fakeTokenCounter);
const openai = await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
tokenCounter: counter,
});
const callsAfterOpenAI = counter.mock.calls.length;
const anthropic = await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.ANTHROPIC,
tokenCounter: fakeTokenCounter,
tokenCounter: counter,
});
expect(openai).not.toBe(anthropic);
expect(mockCacheStore.size).toBe(2);
expect(counter.mock.calls.length).toBe(callsAfterOpenAI);
expect(mockCacheStore.size).toBe(1);
});
it('reads per-tool cache entries in one batch', async () => {
await getOrComputeToolTokens({
toolDefinitions: [makeToolDef('tool_a'), makeToolDef('tool_b')],
provider: Providers.OPENAI,
tokenCounter: fakeTokenCounter,
});
expect(mockGetMany).toHaveBeenCalledTimes(1);
expect(mockGetMany.mock.calls[0][0]).toHaveLength(2);
expect(mockGet).not.toHaveBeenCalled();
});
it('only computes new tools when tool set grows', async () => {
@@ -410,7 +452,7 @@ describe('getOrComputeToolTokens', () => {
expect(mockCacheStore.size).toBe(2);
});
it('separates cache namespaces when provider or model changes', async () => {
it('reuses cache across models with the same tokenizer namespace', async () => {
const counter = jest.fn(fakeTokenCounter);
const defs = [makeToolDef('tool')];
@@ -425,7 +467,7 @@ describe('getOrComputeToolTokens', () => {
await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
clientOptions: { model: 'gpt-4o' },
clientOptions: { model: 'gpt-4.1' },
tokenCounter: counter,
});
expect(counter.mock.calls.length).toBe(callsAfterFirst);
@@ -441,6 +483,43 @@ describe('getOrComputeToolTokens', () => {
expect(mockCacheStore.size).toBe(2);
});
it('separates raw cache entries across tokenizer namespaces', async () => {
const counter = jest.fn(fakeTokenCounter);
const defs = [makeToolDef('tool')];
await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.OPENAI,
clientOptions: { model: 'gpt-4o' },
tokenCounter: counter,
});
const callsAfterFirst = counter.mock.calls.length;
await getOrComputeToolTokens({
toolDefinitions: defs,
provider: Providers.BEDROCK,
clientOptions: { model: 'anthropic.claude-3-5-sonnet-20241022-v2:0' },
tokenCounter: counter,
});
expect(counter.mock.calls.length).toBe(callsAfterFirst + 1);
expect(mockCacheStore.size).toBe(2);
});
it('ignores invalid cached values and recomputes', async () => {
mockGetMany.mockResolvedValueOnce(['not-a-number']);
const counter = jest.fn(fakeTokenCounter);
const result = await getOrComputeToolTokens({
toolDefinitions: [makeToolDef('tool')],
provider: Providers.OPENAI,
tokenCounter: counter,
});
expect(result).toBeGreaterThan(0);
expect(counter).toHaveBeenCalledTimes(1);
});
it('scopes cache keys by tenantId when provided', async () => {
const defs = [makeToolDef('tool')];
@@ -496,7 +575,7 @@ describe('getOrComputeToolTokens', () => {
});
it('falls back to compute when cache read throws', async () => {
mockGet.mockRejectedValueOnce(new Error('Redis down'));
mockGetMany.mockRejectedValueOnce(new Error('Redis down'));
const result = await getOrComputeToolTokens({
toolDefinitions: [makeToolDef('tool')],
@@ -505,11 +584,11 @@ describe('getOrComputeToolTokens', () => {
});
expect(result).toBeGreaterThan(0);
expect(mockGet).toHaveBeenCalled();
expect(mockGetMany).toHaveBeenCalled();
});
it('does not throw when cache write fails', async () => {
mockSet.mockRejectedValueOnce(new Error('Redis write error'));
mockSetMany.mockRejectedValueOnce(new Error('Redis write error'));
const result = await getOrComputeToolTokens({
toolDefinitions: [makeToolDef('tool_write_fail')],
@@ -518,7 +597,7 @@ describe('getOrComputeToolTokens', () => {
});
expect(result).toBeGreaterThan(0);
expect(mockSet).toHaveBeenCalled();
expect(mockSetMany).toHaveBeenCalled();
});
it('matches computeToolSchemaTokens output for same inputs', async () => {

View file

@@ -54,9 +54,14 @@ function hashForCache(value: string): string {
return createHash('sha256').update(value).digest('base64url').slice(0, 16);
}
function getCounterNamespace(provider: Providers, clientOptions?: ClientOptions): string {
/** Mirrors @librechat/agents encodingForModel; raw cached counts are tokenizer-scoped. */
function getCounterNamespace(clientOptions?: ClientOptions): string {
const model = String((clientOptions as { model?: string } | undefined)?.model ?? '');
return hashForCache(`${provider}:${model}`);
return model.toLowerCase().includes('claude') ? 'claude' : 'o200k_base';
}
function isPositiveFiniteNumber(value: unknown): value is number {
return typeof value === 'number' && Number.isFinite(value) && value > 0;
}
function isDirectToolDefinition(def: LCTool, discoveredTools?: ReadonlySet<string>): boolean {
@@ -195,7 +200,7 @@ export function computeToolSchemaTokens(
/**
* Returns tool schema tokens, using per-tool caching to avoid redundant
* token counting. Each tool's raw (pre-multiplier) token count is cached
* individually, keyed by tenant, provider/model namespace, tool name/type, and
* individually, keyed by tenant, tokenizer namespace, tool name/type, and
* schema fingerprint. The provider-specific multiplier is applied to the sum.
*
* Returns 0 if there are no tools.
@@ -225,7 +230,7 @@ export async function getOrComputeToolTokens({
}
const keyPrefix = tenantId ? `${tenantId}:` : '';
const counterNamespace = getCounterNamespace(provider, clientOptions);
const counterNamespace = getCounterNamespace(clientOptions);
let cache: Keyv | undefined;
try {
@@ -234,38 +239,47 @@
logger.debug('[toolTokens] Cache init failed, computing fresh', err);
}
const keyedEntries = entries.map(({ cacheKey, json }) => ({
key: `${keyPrefix}${counterNamespace}:${cacheKey}:${hashForCache(json)}`,
json,
}));
const cachedCounts: Array<number | undefined> = new Array(keyedEntries.length);
if (cache) {
const activeCache = cache;
try {
const readResults = await activeCache.getMany<number>(keyedEntries.map(({ key }) => key));
for (let i = 0; i < readResults.length; i++) {
if (isPositiveFiniteNumber(readResults[i])) {
cachedCounts[i] = readResults[i];
}
}
} catch (err) {
logger.debug('[toolTokens] Cache batch read failed, computing misses fresh', err);
}
}
let rawTotal = 0;
const toWrite: Array<{ key: string; value: number }> = [];
for (const { cacheKey, json } of entries) {
const fullKey = `${keyPrefix}${counterNamespace}:${cacheKey}:${hashForCache(json)}`;
let rawCount: number | undefined;
for (let i = 0; i < keyedEntries.length; i++) {
const { key, json } = keyedEntries[i];
let rawCount = cachedCounts[i];
if (cache) {
try {
rawCount = (await cache.get(fullKey)) as number | undefined;
} catch {
// Cache read failed for this tool — will compute fresh
}
}
if (rawCount == null || rawCount <= 0) {
if (rawCount == null) {
rawCount = tokenCounter(new SystemMessage(json));
if (rawCount > 0 && cache) {
toWrite.push({ key: fullKey, value: rawCount });
toWrite.push({ key, value: rawCount });
}
}
rawTotal += rawCount;
}
// Fire-and-forget cache writes for newly computed tools
if (cache && toWrite.length > 0) {
for (const { key, value } of toWrite) {
cache.set(key, value).catch((err: unknown) => {
logger.debug(`[toolTokens] Cache write failed for ${key}`, err);
});
}
cache.setMany(toWrite).catch((err: unknown) => {
logger.debug('[toolTokens] Cache batch write failed', err);
});
}
const multiplier = getToolTokenMultiplier(provider, clientOptions);

View file

@@ -1839,7 +1839,7 @@ export enum CacheKeys {
ADMIN_OAUTH_EXCHANGE = 'ADMIN_OAUTH_EXCHANGE',
/**
* Key for cached tool schema token counts.
* Entries are keyed by tenant, provider/model namespace, tool name/type, and schema fingerprint.
* Entries are keyed by tenant, tokenizer namespace, tool name/type, and schema fingerprint.
*/
TOOL_TOKENS = 'TOOL_TOKENS',
}