mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
mlx: avoid status timeout during inference (#16086)
The MLX runner now routes model work through a locked worker thread. Status also used that worker only to sample memory, so a scheduler health ping could sit behind long prefill or generation until its 10s context expired, causing /v1/status to return 500 and the server to treat the runner as unhealthy. While Metal doesn't change VRAM reporting, CUDA does. Cache the last memory sample and make status perform only a short best-effort refresh. If the worker is busy, status returns the cached value while a single background refresh continues and updates the cache when the worker becomes available. The in-flight guard and lifecycle context keep this from spawning unbounded refreshes while preserving live VRAM refresh behavior for CUDA. Fixes #16081
This commit is contained in:
parent
d819ef0f97
commit
206b049508
3 changed files with 375 additions and 10 deletions
|
|
@ -55,6 +55,8 @@ func Execute(args []string) error {
|
|||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
})
|
||||
runnerCtx, cancelRunner := context.WithCancel(context.Background())
|
||||
defer cancelRunner()
|
||||
|
||||
runner := Runner{
|
||||
Requests: make(chan Request),
|
||||
|
|
@ -67,22 +69,30 @@ func Execute(args []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
readMemory := func() (uint64, error) {
|
||||
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
|
||||
}
|
||||
initialMemory, err := mlxthread.Call(context.Background(), worker, readMemory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
memoryCache := newStatusMemoryCache(
|
||||
runnerCtx,
|
||||
initialMemory,
|
||||
time.Now(),
|
||||
statusMemoryRefreshWait,
|
||||
func() (uint64, error) {
|
||||
return mlxthread.Call(runnerCtx, worker, readMemory)
|
||||
},
|
||||
)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||
memory, err := mlxthread.Call(r.Context(), worker, func() (uint64, error) {
|
||||
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("Failed to read MLX memory status", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
ContextLength: runner.contextLength,
|
||||
Memory: memory,
|
||||
Memory: memoryCache.Memory(),
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
|
|
|
|||
109
x/mlxrunner/status_memory.go
Normal file
109
x/mlxrunner/status_memory.go
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const statusMemoryRefreshWait = 50 * time.Millisecond
|
||||
|
||||
type statusMemoryRefreshFunc func() (uint64, error)
|
||||
|
||||
// statusMemoryCache keeps health checks from depending synchronously on the
|
||||
// serialized MLX worker while still refreshing memory telemetry opportunistically.
|
||||
type statusMemoryCache struct {
|
||||
done <-chan struct{}
|
||||
wait time.Duration
|
||||
refresh statusMemoryRefreshFunc
|
||||
|
||||
mu sync.Mutex
|
||||
memory uint64
|
||||
refreshedAt time.Time
|
||||
inFlight chan struct{}
|
||||
}
|
||||
|
||||
func newStatusMemoryCache(ctx context.Context, memory uint64, refreshedAt time.Time, wait time.Duration, refresh statusMemoryRefreshFunc) *statusMemoryCache {
|
||||
return &statusMemoryCache{
|
||||
done: ctx.Done(),
|
||||
wait: wait,
|
||||
refresh: refresh,
|
||||
memory: memory,
|
||||
refreshedAt: refreshedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *statusMemoryCache) Memory() uint64 {
|
||||
done := c.startRefresh()
|
||||
if c.wait <= 0 {
|
||||
<-done
|
||||
memory, _ := c.snapshot()
|
||||
return memory
|
||||
}
|
||||
|
||||
timer := time.NewTimer(c.wait)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-timer.C:
|
||||
memory, refreshedAt := c.snapshot()
|
||||
if refreshedAt.IsZero() {
|
||||
slog.Debug("using cached MLX memory status before first refresh")
|
||||
} else {
|
||||
slog.Debug("using cached MLX memory status", "stale", time.Since(refreshedAt))
|
||||
}
|
||||
return memory
|
||||
case <-c.done:
|
||||
}
|
||||
|
||||
memory, _ := c.snapshot()
|
||||
return memory
|
||||
}
|
||||
|
||||
func (c *statusMemoryCache) startRefresh() chan struct{} {
|
||||
c.mu.Lock()
|
||||
if c.inFlight != nil {
|
||||
done := c.inFlight
|
||||
c.mu.Unlock()
|
||||
return done
|
||||
}
|
||||
|
||||
refreshDone := make(chan struct{})
|
||||
c.inFlight = refreshDone
|
||||
refresh := c.refresh
|
||||
lifecycleDone := c.done
|
||||
c.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
memory, err := refresh()
|
||||
now := time.Now()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
defer close(refreshDone)
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case <-lifecycleDone:
|
||||
default:
|
||||
slog.Debug("failed to refresh MLX memory status", "error", err)
|
||||
}
|
||||
c.inFlight = nil
|
||||
return
|
||||
}
|
||||
|
||||
c.memory = memory
|
||||
c.refreshedAt = now
|
||||
c.inFlight = nil
|
||||
}()
|
||||
|
||||
return refreshDone
|
||||
}
|
||||
|
||||
func (c *statusMemoryCache) snapshot() (uint64, time.Time) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.memory, c.refreshedAt
|
||||
}
|
||||
246
x/mlxrunner/status_memory_test.go
Normal file
246
x/mlxrunner/status_memory_test.go
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStatusMemoryCacheWaitsForFastRefresh(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
calls.Add(1)
|
||||
return 42, nil
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 42 {
|
||||
t.Fatalf("got memory %d, want 42", got)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheSupportsBlockingWait(t *testing.T) {
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), 0, func() (uint64, error) {
|
||||
return 42, nil
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 42 {
|
||||
t.Fatalf("got memory %d, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheReturnsCachedValueAndRefreshesLater(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
close(started)
|
||||
}
|
||||
select {
|
||||
case <-release:
|
||||
return 42, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
start := time.Now()
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed > time.Second {
|
||||
t.Fatalf("cached memory lookup took too long: %s", elapsed)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
close(release)
|
||||
waitForCachedMemory(t, cache, 42)
|
||||
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheReturnsCachedValueBeforeFirstRefresh(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
cache := newStatusMemoryCache(ctx, 7, time.Time{}, time.Millisecond, func() (uint64, error) {
|
||||
close(started)
|
||||
select {
|
||||
case <-release:
|
||||
return 42, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
close(release)
|
||||
waitForCachedMemory(t, cache, 42)
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheKeepsCachedValueWhenRefreshFails(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
calls.Add(1)
|
||||
return 0, errors.New("refresh failed")
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheReturnsCachedValueWhenContextDone(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
close(started)
|
||||
<-release
|
||||
return 0, ctx.Err()
|
||||
})
|
||||
|
||||
cancel()
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
close(release)
|
||||
waitForInflightRefresh(t, cache)
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheAllowsRefreshAfterFailure(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
return 0, errors.New("refresh failed")
|
||||
}
|
||||
return 42, nil
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
if got := cache.Memory(); got != 42 {
|
||||
t.Fatalf("got memory %d after retry, want 42", got)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Fatalf("refresh calls = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheAllowsOneInflightRefresh(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
cache := newStatusMemoryCache(ctx, 11, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
close(started)
|
||||
}
|
||||
select {
|
||||
case <-release:
|
||||
return 99, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
const goroutines = 8
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan string, goroutines)
|
||||
for range goroutines {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if got := cache.Memory(); got != 11 {
|
||||
errCh <- "got non-cached memory value"
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(release)
|
||||
waitForCachedMemory(t, cache, 99)
|
||||
}
|
||||
|
||||
func waitForRefreshStart(t *testing.T, started <-chan struct{}) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for refresh to start")
|
||||
}
|
||||
}
|
||||
|
||||
func waitForCachedMemory(t *testing.T, cache *statusMemoryCache, want uint64) {
|
||||
t.Helper()
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
got, _ := cache.snapshot()
|
||||
if got == want {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatalf("cached memory = %d, want %d", got, want)
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForInflightRefresh(t *testing.T, cache *statusMemoryCache) {
|
||||
t.Helper()
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
cache.mu.Lock()
|
||||
inFlight := cache.inFlight
|
||||
cache.mu.Unlock()
|
||||
if inFlight == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatal("timeout waiting for refresh to finish")
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue