From a2105df07e8d816ade9f325f29a44a45cb5e4d99 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Sat, 9 May 2026 12:37:14 -0400 Subject: [PATCH] fix: optimize tool token cache alignment --- packages/api/src/agents/toolTokens.spec.ts | 103 ++++++++++++++++++--- packages/api/src/agents/toolTokens.ts | 60 +++++++----- packages/data-provider/src/config.ts | 2 +- 3 files changed, 129 insertions(+), 36 deletions(-) diff --git a/packages/api/src/agents/toolTokens.spec.ts b/packages/api/src/agents/toolTokens.spec.ts index 9a33285007..9ed7ea39f5 100644 --- a/packages/api/src/agents/toolTokens.spec.ts +++ b/packages/api/src/agents/toolTokens.spec.ts @@ -10,16 +10,30 @@ import type { GenericTool, LCTool, LCToolRegistry, TokenCounter } from '@librech import { collectToolSchemas, computeToolSchemaTokens, getOrComputeToolTokens } from './toolTokens'; -/* ---------- Mock standardCache with hoisted get/set for per-test overrides ---------- */ +/* ---------- Mock standardCache with hoisted methods for per-test overrides ---------- */ const mockCacheStore = new Map(); const mockGet = jest.fn((key: string) => Promise.resolve(mockCacheStore.get(key))); +const mockGetMany = jest.fn((keys: string[]) => + Promise.resolve(keys.map((key) => mockCacheStore.get(key))), +); const mockSet = jest.fn((key: string, value: unknown) => { mockCacheStore.set(key, value); return Promise.resolve(true); }); +const mockSetMany = jest.fn((entries: Array<{ key: string; value: unknown }>) => { + for (const { key, value } of entries) { + mockCacheStore.set(key, value); + } + return Promise.resolve(entries.map(() => true)); +}); jest.mock('~/cache', () => ({ - standardCache: jest.fn(() => ({ get: mockGet, set: mockSet })), + standardCache: jest.fn(() => ({ + get: mockGet, + getMany: mockGetMany, + set: mockSet, + setMany: mockSetMany, + })), })); jest.mock('@librechat/data-schemas', () => ({ @@ -68,11 +82,24 @@ function hasCacheKey(fragment: string): boolean { beforeEach(() => { mockCacheStore.clear(); + 
mockGet.mockClear(); + mockGetMany.mockClear(); + mockSet.mockClear(); + mockSetMany.mockClear(); mockGet.mockImplementation((key: string) => Promise.resolve(mockCacheStore.get(key))); + mockGetMany.mockImplementation((keys: string[]) => + Promise.resolve(keys.map((key) => mockCacheStore.get(key))), + ); mockSet.mockImplementation((key: string, value: unknown) => { mockCacheStore.set(key, value); return Promise.resolve(true); }); + mockSetMany.mockImplementation((entries: Array<{ key: string; value: unknown }>) => { + for (const { key, value } of entries) { + mockCacheStore.set(key, value); + } + return Promise.resolve(entries.map(() => true)); + }); }); /* ========================================================================= */ @@ -331,23 +358,38 @@ describe('getOrComputeToolTokens', () => { expect(counter.mock.calls.length).toBe(callCountAfterFirst); }); - it('applies different multipliers while separating provider namespaces', async () => { + it('applies different multipliers while reusing raw tokenizer counts', async () => { const defs = [makeToolDef('tool')]; + const counter = jest.fn(fakeTokenCounter); const openai = await getOrComputeToolTokens({ toolDefinitions: defs, provider: Providers.OPENAI, - tokenCounter: fakeTokenCounter, + tokenCounter: counter, }); + const callsAfterOpenAI = counter.mock.calls.length; const anthropic = await getOrComputeToolTokens({ toolDefinitions: defs, provider: Providers.ANTHROPIC, - tokenCounter: fakeTokenCounter, + tokenCounter: counter, }); expect(openai).not.toBe(anthropic); - expect(mockCacheStore.size).toBe(2); + expect(counter.mock.calls.length).toBe(callsAfterOpenAI); + expect(mockCacheStore.size).toBe(1); + }); + + it('reads per-tool cache entries in one batch', async () => { + await getOrComputeToolTokens({ + toolDefinitions: [makeToolDef('tool_a'), makeToolDef('tool_b')], + provider: Providers.OPENAI, + tokenCounter: fakeTokenCounter, + }); + + expect(mockGetMany).toHaveBeenCalledTimes(1); + 
expect(mockGetMany.mock.calls[0][0]).toHaveLength(2); + expect(mockGet).not.toHaveBeenCalled(); }); it('only computes new tools when tool set grows', async () => { @@ -410,7 +452,7 @@ describe('getOrComputeToolTokens', () => { expect(mockCacheStore.size).toBe(2); }); - it('separates cache namespaces when provider or model changes', async () => { + it('reuses cache across models with the same tokenizer namespace', async () => { const counter = jest.fn(fakeTokenCounter); const defs = [makeToolDef('tool')]; @@ -425,7 +467,7 @@ describe('getOrComputeToolTokens', () => { await getOrComputeToolTokens({ toolDefinitions: defs, provider: Providers.OPENAI, - clientOptions: { model: 'gpt-4o' }, + clientOptions: { model: 'gpt-4.1' }, tokenCounter: counter, }); expect(counter.mock.calls.length).toBe(callsAfterFirst); @@ -441,6 +483,43 @@ describe('getOrComputeToolTokens', () => { expect(mockCacheStore.size).toBe(2); }); + it('separates raw cache entries across tokenizer namespaces', async () => { + const counter = jest.fn(fakeTokenCounter); + const defs = [makeToolDef('tool')]; + + await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.OPENAI, + clientOptions: { model: 'gpt-4o' }, + tokenCounter: counter, + }); + const callsAfterFirst = counter.mock.calls.length; + + await getOrComputeToolTokens({ + toolDefinitions: defs, + provider: Providers.BEDROCK, + clientOptions: { model: 'anthropic.claude-3-5-sonnet-20241022-v2:0' }, + tokenCounter: counter, + }); + + expect(counter.mock.calls.length).toBe(callsAfterFirst + 1); + expect(mockCacheStore.size).toBe(2); + }); + + it('ignores invalid cached values and recomputes', async () => { + mockGetMany.mockResolvedValueOnce(['not-a-number']); + const counter = jest.fn(fakeTokenCounter); + + const result = await getOrComputeToolTokens({ + toolDefinitions: [makeToolDef('tool')], + provider: Providers.OPENAI, + tokenCounter: counter, + }); + + expect(result).toBeGreaterThan(0); + 
expect(counter).toHaveBeenCalledTimes(1); + }); + it('scopes cache keys by tenantId when provided', async () => { const defs = [makeToolDef('tool')]; @@ -496,7 +575,7 @@ describe('getOrComputeToolTokens', () => { }); it('falls back to compute when cache read throws', async () => { - mockGet.mockRejectedValueOnce(new Error('Redis down')); + mockGetMany.mockRejectedValueOnce(new Error('Redis down')); const result = await getOrComputeToolTokens({ toolDefinitions: [makeToolDef('tool')], @@ -505,11 +584,11 @@ describe('getOrComputeToolTokens', () => { }); expect(result).toBeGreaterThan(0); - expect(mockGet).toHaveBeenCalled(); + expect(mockGetMany).toHaveBeenCalled(); }); it('does not throw when cache write fails', async () => { - mockSet.mockRejectedValueOnce(new Error('Redis write error')); + mockSetMany.mockRejectedValueOnce(new Error('Redis write error')); const result = await getOrComputeToolTokens({ toolDefinitions: [makeToolDef('tool_write_fail')], @@ -518,7 +597,7 @@ describe('getOrComputeToolTokens', () => { }); expect(result).toBeGreaterThan(0); - expect(mockSet).toHaveBeenCalled(); + expect(mockSetMany).toHaveBeenCalled(); }); it('matches computeToolSchemaTokens output for same inputs', async () => { diff --git a/packages/api/src/agents/toolTokens.ts b/packages/api/src/agents/toolTokens.ts index e2005293df..e62ea463b3 100644 --- a/packages/api/src/agents/toolTokens.ts +++ b/packages/api/src/agents/toolTokens.ts @@ -54,9 +54,14 @@ function hashForCache(value: string): string { return createHash('sha256').update(value).digest('base64url').slice(0, 16); } -function getCounterNamespace(provider: Providers, clientOptions?: ClientOptions): string { +/** Mirrors @librechat/agents encodingForModel; raw cached counts are tokenizer-scoped. */ +function getCounterNamespace(clientOptions?: ClientOptions): string { const model = String((clientOptions as { model?: string } | undefined)?.model ?? 
''); - return hashForCache(`${provider}:${model}`); + return model.toLowerCase().includes('claude') ? 'claude' : 'o200k_base'; +} + +function isPositiveFiniteNumber(value: unknown): value is number { + return typeof value === 'number' && Number.isFinite(value) && value > 0; } function isDirectToolDefinition(def: LCTool, discoveredTools?: ReadonlySet<string>): boolean { @@ -195,7 +200,7 @@ export function computeToolSchemaTokens( /** * Returns tool schema tokens, using per-tool caching to avoid redundant * token counting. Each tool's raw (pre-multiplier) token count is cached - * individually, keyed by tenant, provider/model namespace, tool name/type, and + * individually, keyed by tenant, tokenizer namespace, tool name/type, and * schema fingerprint. The provider-specific multiplier is applied to the sum. * * Returns 0 if there are no tools. @@ -225,7 +230,7 @@ export async function getOrComputeToolTokens({ } const keyPrefix = tenantId ? `${tenantId}:` : ''; - const counterNamespace = getCounterNamespace(provider, clientOptions); + const counterNamespace = getCounterNamespace(clientOptions); let cache: Keyv | undefined; try { @@ -234,38 +239,47 @@ export async function getOrComputeToolTokens({ logger.debug('[toolTokens] Cache init failed, computing fresh', err); } + const keyedEntries = entries.map(({ cacheKey, json }) => ({ + key: `${keyPrefix}${counterNamespace}:${cacheKey}:${hashForCache(json)}`, + json, + })); + const cachedCounts: Array<number | undefined> = new Array(keyedEntries.length); + + if (cache) { + const activeCache = cache; + try { + const readResults = await activeCache.getMany(keyedEntries.map(({ key }) => key)); + for (let i = 0; i < readResults.length; i++) { + if (isPositiveFiniteNumber(readResults[i])) { + cachedCounts[i] = readResults[i]; + } + } + } catch (err) { + logger.debug('[toolTokens] Cache batch read failed, computing misses fresh', err); + } + } + let rawTotal = 0; const toWrite: Array<{ key: string; value: number }> = []; - for (const { cacheKey, json } of 
entries) { - const fullKey = `${keyPrefix}${counterNamespace}:${cacheKey}:${hashForCache(json)}`; - let rawCount: number | undefined; + for (let i = 0; i < keyedEntries.length; i++) { + const { key, json } = keyedEntries[i]; + let rawCount = cachedCounts[i]; - if (cache) { - try { - rawCount = (await cache.get(fullKey)) as number | undefined; - } catch { - // Cache read failed for this tool — will compute fresh - } - } - - if (rawCount == null || rawCount <= 0) { + if (rawCount == null) { rawCount = tokenCounter(new SystemMessage(json)); if (rawCount > 0 && cache) { - toWrite.push({ key: fullKey, value: rawCount }); + toWrite.push({ key, value: rawCount }); } } rawTotal += rawCount; } - // Fire-and-forget cache writes for newly computed tools if (cache && toWrite.length > 0) { - for (const { key, value } of toWrite) { - cache.set(key, value).catch((err: unknown) => { - logger.debug(`[toolTokens] Cache write failed for ${key}`, err); - }); - } + cache.setMany(toWrite).catch((err: unknown) => { + logger.debug('[toolTokens] Cache batch write failed', err); + }); } const multiplier = getToolTokenMultiplier(provider, clientOptions); diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts index 9904d82649..01f9e9bef1 100644 --- a/packages/data-provider/src/config.ts +++ b/packages/data-provider/src/config.ts @@ -1839,7 +1839,7 @@ export enum CacheKeys { ADMIN_OAUTH_EXCHANGE = 'ADMIN_OAUTH_EXCHANGE', /** * Key for cached tool schema token counts. - * Entries are keyed by tenant, provider/model namespace, tool name/type, and schema fingerprint. + * Entries are keyed by tenant, tokenizer namespace, tool name/type, and schema fingerprint. */ TOOL_TOKENS = 'TOOL_TOKENS', }