From 77491439c282192a078fdf349f0fecd686490294 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 19 Mar 2026 11:20:50 -0700 Subject: [PATCH] mlxrunner: support partial match on pure transformer caches Previously, a partial match within a node's edge would truncate the path to the parent snapshot - effectively making all cache types behave as recurrent caches. Caches with only transformer layers can rewind to arbitrary boundary so this restores this capability to improve cache hits --- x/mlxrunner/cache.go | 51 ++++++------ x/mlxrunner/cache/cache.go | 28 +++++-- x/mlxrunner/cache/recurrent.go | 8 +- x/mlxrunner/cache/recurrent_test.go | 34 ++++---- x/mlxrunner/cache_test.go | 118 ++++++++++++++++++---------- 5 files changed, 140 insertions(+), 99 deletions(-) diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go index a5709101d..f403b5dfc 100644 --- a/x/mlxrunner/cache.go +++ b/x/mlxrunner/cache.go @@ -93,21 +93,8 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1]) } - // Check for partial match within a node's edge — truncate path - // to the parent boundary. snapshot() will split the node and - // create the branch point during prefill when caches are ready. - partialMatch := false - if len(matchPath) > 1 { - lastNode := matchPath[len(matchPath)-1] - matchedInEdge := matched - lastNode.startOffset() - if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) { - matchPath = matchPath[:len(matchPath)-1] - partialMatch = true - } - } - // Switch to the matched path, paging in/out as needed. - c.switchToPath(matchPath) + c.switchToPath(matchPath, matched) // switchToPath aligns caches to a common offset prefix := c.minCacheOffset() @@ -116,7 +103,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { // Schedule a snapshot at the branch point during prefill so future // requests diverging here can restore instead of re-evaluating. var snapshotAt int - if partialMatch || (prefix == 0 && matched > 0) { + if prefix < matched { snapshotAt = matched } @@ -142,7 +129,7 @@ func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession { // switchToPath transitions from the current active path to a new path, // paging out diverging segments and paging in the new path. -func (c *kvCache) switchToPath(newPath []*trieNode) { +func (c *kvCache) switchToPath(newPath []*trieNode, matched int) { defer c.enforceEvictionPolicy() // Find common ancestor index. @@ -167,7 +154,10 @@ func (c *kvCache) switchToPath(newPath []*trieNode) { // non-leaf nodes here would produce wrong results for non-rewindable // caches (e.g. RecurrentCache) whose state reflects the leaf, not // the intermediate boundary. - if leaf := len(c.activePath) - 1; leaf >= commonLen { + leaf := len(c.activePath) - 1 + leafDiverges := leaf >= commonLen + leafNeedsRewind := matched < c.activePath[leaf].endOffset + if leafDiverges || leafNeedsRewind { node := c.activePath[leaf] if !node.hasAllSnapshots() { fromOffset := node.startOffset() @@ -184,14 +174,16 @@ func (c *kvCache) switchToPath(newPath []*trieNode) { } } - // Rewind each cache to the ancestor offset or free it. Freed - // caches (e.g. RecurrentCache that can't rewind) will be restored - // from snapshots during page-in. + // Rewind each cache to the target offset or free it. When matched + // falls within the ancestor's range (same-path case), we rewind + // directly to the match point. Otherwise we rewind to the ancestor + // and let page-in bring us forward to matched. + rewindTarget := min(ancestorOffset, matched) for _, kv := range c.caches { if kv == nil { continue } - if !kv.Restore(nil, ancestorOffset) { + if !kv.Restore(nil, rewindTarget) { kv.Free() } } @@ -199,10 +191,12 @@ func (c *kvCache) switchToPath(newPath []*trieNode) { // Page in — walk the full new path, restoring from snapshots. // Freed caches naturally pick up the first available snapshot. // Caches already past a node skip it via offset check. +pageIn: for _, node := range newPath { - if len(node.snapshots) == 0 { + if !node.hasSnapshots() { continue } + nodeTarget := min(node.endOffset, matched) for j, kv := range c.caches { if kv == nil { continue @@ -210,19 +204,18 @@ func (c *kvCache) switchToPath(newPath []*trieNode) { if j >= len(node.snapshots) || node.snapshots[j] == nil { continue } - if kv.Offset() >= node.endOffset { + if kv.Offset() >= nodeTarget { continue } - if !kv.Restore(node.snapshots[j], node.endOffset) { - slog.Warn("cache restore failure during page-in, freeing all caches", "layer", j, "offset", node.startOffset()) - c.freeAll() - c.activePath = []*trieNode{c.root} - return + if !kv.Restore(node.snapshots[j], nodeTarget) { + // Restore failed — stop page-in and let alignment + // bring all caches to a consistent offset. + break pageIn } } if node.endOffset > ancestorOffset { pageInCount++ - logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), node.endOffset)) + logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget)) } } diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go index 8e024115d..87b42e6cc 100644 --- a/x/mlxrunner/cache/cache.go +++ b/x/mlxrunner/cache/cache.go @@ -17,7 +17,8 @@ type Cache interface { Snapshot(fromOffset int) Snapshot // Restore brings the cache to target. If snapshot is nil, rewinds - // using the cache's own live state. + // using the cache's own live state. Returns false if the target is + // unreachable (e.g. target > current offset, or negative). Restore(snapshot Snapshot, target int) bool // Merge combines two sequential snapshots [a,b) and [b,c) into [a,c). @@ -122,17 +123,21 @@ func (c *KVCache) Snapshot(fromOffset int) Snapshot { } func (c *KVCache) Restore(snapshot Snapshot, target int) bool { + if target < 0 { + return false + } + if snapshot == nil { - // Rewind using live state — just clamp offset. - target = max(0, min(target, c.offset)) + if target > c.offset { + return false + } c.offset = target return true } snap := snapshot.(*kvSnapshot) - // Check that the cache has data up to the snapshot's starting point. - if c.offset < snap.fromOffset { + if target > snap.toOffset || c.offset < snap.fromOffset { return false } @@ -354,7 +359,14 @@ func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot { } func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { + if target < 0 { + return false + } + if snapshot == nil { + if target >= c.offset { + return target == c.offset + } // Live rewind is only safe when the buffer hasn't filled yet // (offset <= maxSize). Once the window has shifted, rewinding // leaves fewer than maxSize trailing tokens to attend to — @@ -362,7 +374,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { if c.offset > c.maxSize { return false } - target = max(0, min(target, c.offset)) c.offset = target c.idx = target return true @@ -370,6 +381,10 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { snap := snapshot.(*rotatingSnapshot) + if target > snap.toOffset { + return false + } + // Reject if clamping would leave an incomplete window. if target < snap.toOffset && snap.toOffset > c.maxSize { return false @@ -388,7 +403,6 @@ func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool { // Clamp to target if needed. if target < c.offset { - target = max(0, target) c.offset = target c.idx = target } diff --git a/x/mlxrunner/cache/recurrent.go b/x/mlxrunner/cache/recurrent.go index 4025c69a3..2c02ff3d4 100644 --- a/x/mlxrunner/cache/recurrent.go +++ b/x/mlxrunner/cache/recurrent.go @@ -150,10 +150,10 @@ func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool { snap := snapshot.(*recurrentSnapshot) - // Recurrent state encodes all tokens up to snap.offset. Restoring - // to a target before that would leave stale state from tokens - // [target, snap.offset) baked in. Only allow restoring forward. - if target < snap.offset { + // Recurrent snapshots encode cumulative state up to exactly + // snap.offset. Target must match — rewinding would leave stale + // state, and advancing isn't possible without feeding tokens. + if target != snap.offset { return false } diff --git a/x/mlxrunner/cache/recurrent_test.go b/x/mlxrunner/cache/recurrent_test.go index 64d482593..ef8b7f7a3 100644 --- a/x/mlxrunner/cache/recurrent_test.go +++ b/x/mlxrunner/cache/recurrent_test.go @@ -6,39 +6,35 @@ import ( "github.com/ollama/ollama/x/mlxrunner/mlx" ) -// TestRecurrentCacheRestoreDirectionality verifies that RecurrentCache only -// allows restoring forward (target >= snapshot offset), never backward. -func TestRecurrentCacheRestoreDirectionality(t *testing.T) { +// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore +// only succeeds when target exactly matches the snapshot's offset. Recurrent +// state is cumulative, so it can't be rewound or fast-forwarded. +func TestRecurrentCacheRestoreExactOffset(t *testing.T) { skipIfNoMLX(t) c := NewRecurrentCache(3, 12, 4, 8, 8) _ = c.ConvState(1, mlx.DTypeFloat16) _ = c.DeltaState(1, mlx.DTypeFloat16) c.Advance(10) - snap := c.Snapshot(0) + snap := c.Snapshot(0) // snap.offset == 10 - c.Advance(5) // now at 15 + c.Advance(5) // cache now at 15 - // Restore backward should fail. + // target < snap.offset: fails (can't rewind past snapshot) if c.Restore(snap, 5) { - t.Fatal("Restore(snap, 5) should fail — target < snap.offset") + t.Fatal("Restore(snap, 5) should fail — target != snap.offset") } - // Restore to exact snap offset should succeed. + // target > snap.offset: fails (can't advance without feeding tokens) + if c.Restore(snap, 15) { + t.Fatal("Restore(snap, 15) should fail — target != snap.offset") + } + + // target == snap.offset: succeeds if !c.Restore(snap, 10) { - t.Fatal("Restore(snap, 10) should succeed") + t.Fatal("Restore(snap, 10) should succeed — target == snap.offset") } if c.Offset() != 10 { t.Fatalf("offset = %d, want 10", c.Offset()) } - - // Restore forward (target > snap offset) should succeed, offset = snap.offset. - snap2 := c.Snapshot(0) - if !c.Restore(snap2, 15) { - t.Fatal("Restore(snap, 15) should succeed") - } - // Recurrent state is at snap.offset (10), not target (15). - if c.Offset() != 10 { - t.Fatalf("offset = %d, want 10 (snap offset)", c.Offset()) - } } diff --git a/x/mlxrunner/cache_test.go b/x/mlxrunner/cache_test.go index 13524b124..21e2f53ce 100644 --- a/x/mlxrunner/cache_test.go +++ b/x/mlxrunner/cache_test.go @@ -79,20 +79,20 @@ func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot { } func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool { + if target < 0 { + return false + } + if snapshot == nil { - // Rewind live state. - if target < 0 { - target = 0 - } if target > len(c.tokens) { - target = len(c.tokens) + return false } c.tokens = c.tokens[:target] return true } s := snapshot.(*fakeSnapshot) - if len(c.tokens) < s.from { - return false // don't have base data up to snapshot start + if target > s.to || len(c.tokens) < s.from { + return false } c.tokens = append(c.tokens[:s.from], s.tokens...) if target < len(c.tokens) { @@ -196,9 +196,13 @@ func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot { } func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool { + if target < 0 { + return false + } + if snapshot == nil { - if target == len(c.tokens) { - return true + if target >= len(c.tokens) { + return target == len(c.tokens) } // Live rewind only works when buffer hasn't filled (offset <= maxSize). if len(c.tokens) > c.maxSize { @@ -208,6 +212,14 @@ func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bo return true } s := snapshot.(*fakeSnapshot) + if target > s.to { + return false + } + // Reject if clamping would leave an incomplete window + // (matches RotatingKVCache behavior). + if target < s.to && s.to > c.maxSize { + return false + } c.tokens = slices.Clone(s.tokens) if target < len(c.tokens) { c.tokens = c.tokens[:target] @@ -268,8 +280,8 @@ func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool { return target == len(c.tokens) // can only no-op } s := snapshot.(*fakeSnapshot) - if target < s.to { - return false // can't go backward + if target != s.to { + return false // cumulative state requires exact match } c.tokens = slices.Clone(s.tokens) return true @@ -294,9 +306,10 @@ type feedableCache interface { // testEnv encapsulates a kvCache and its fake caches for a test scenario. type testEnv struct { - kvc *kvCache - caches []cache.Cache // typed references for assertions - tracker *snapshotTracker + kvc *kvCache + caches []cache.Cache // typed references for assertions + tracker *snapshotTracker + rewindable bool // true when all caches support arbitrary Restore(nil, target) } // newTransformerEnv creates a test environment with a single rewindable cache @@ -305,23 +318,28 @@ func newTransformerEnv() *testEnv { tracker := &snapshotTracker{} caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}} return &testEnv{ - kvc: &kvCache{caches: caches}, - caches: caches, - tracker: tracker, + kvc: &kvCache{caches: caches}, + caches: caches, + tracker: tracker, + rewindable: true, } } // newSlidingWindowEnv creates a test environment with one rewindable cache and -// one sliding window cache (Mistral-style architecture). +// one sliding window cache (Mistral-style architecture). The sliding window +// maxSize is set small enough that test sequences fill it, making +// Restore(nil, target) fail — the same behavior as production models where +// the window fills after a few turns. func newSlidingWindowEnv() *testEnv { tr := &snapshotTracker{} rc := &fakeRewindableCache{tracker: tr} - sw := &fakeSlidingWindowCache{maxSize: 32, tracker: tr} + sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr} caches := []cache.Cache{rc, sw} return &testEnv{ - kvc: &kvCache{caches: caches}, - caches: caches, - tracker: tr, + kvc: &kvCache{caches: caches}, + caches: caches, + tracker: tr, + rewindable: false, } } @@ -333,9 +351,10 @@ func newRecurrentEnv() *testEnv { nrc := &fakeRecurrentCache{tracker: tr} caches := []cache.Cache{rc, nrc} return &testEnv{ - kvc: &kvCache{caches: caches}, - caches: caches, - tracker: tr, + kvc: &kvCache{caches: caches}, + caches: caches, + tracker: tr, + rewindable: false, } } @@ -590,15 +609,24 @@ func TestBranchCreationAndReuse(t *testing.T) { } // Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A. - // Partial match in A's edge triggers snapshotOffset. + // For rewindable caches, switchToPath rewinds to the match point + // so only the non-matching suffix needs evaluation. For non-rewindable + // caches (RecurrentCache), the rewind fails and freeAll fires. resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31}) - if resB.snapshotOffset != 5 { - t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) - } - // Cache was rewound to 0 (partial match truncates path to root), - // so all tokens were re-evaluated. - if len(resB.remaining) != 8 { - t.Fatalf("B: remaining = %d, want 8", len(resB.remaining)) + if env.rewindable { + if resB.snapshotOffset != 0 { + t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset) + } + if len(resB.remaining) != 3 { + t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining)) + } + } else { + if resB.snapshotOffset != 5 { + t.Fatalf("B: snapshotOffset = %d, want 5", resB.snapshotOffset) + } + if len(resB.remaining) != 8 { + t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining)) + } } env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31}) @@ -635,14 +663,24 @@ func TestExactMatchSeedBehavior(t *testing.T) { simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}) // Request B: identical prompt. Holdback means matched=4, partial in - // the 5-token edge, so path truncates to root and all tokens are - // re-evaluated. snapshotOffset should be set at the holdback point. + // the 5-token edge. For rewindable caches, switchToPath rewinds to + // offset 4, so only the held-back token needs re-evaluation. For + // non-rewindable caches, the rewind fails and freeAll fires. resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21}) - if len(resB.remaining) != 5 { - t.Fatalf("B: remaining = %d, want 5 (full re-eval due to holdback)", len(resB.remaining)) - } - if resB.snapshotOffset != 4 { - t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) + if env.rewindable { + if len(resB.remaining) != 1 { + t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining)) + } + if resB.snapshotOffset != 0 { + t.Fatalf("B: snapshotOffset = %d, want 0 (rewind succeeded)", resB.snapshotOffset) + } + } else { + if len(resB.remaining) != 5 { + t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining)) + } + if resB.snapshotOffset != 4 { + t.Fatalf("B: snapshotOffset = %d, want 4", resB.snapshotOffset) + } } env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})