mlx: avoid status timeout during inference (#16086)

The MLX runner now routes model work through a locked worker thread. Status also used that worker only to sample memory, so a scheduler health ping could sit behind long prefill or generation until its 10s context expired, causing /v1/status to return 500 and the server to treat the runner as unhealthy.

While VRAM reporting doesn't change on Metal, it does on CUDA. Cache the last memory sample and make status perform only a short best-effort refresh. If the worker is busy, status returns the cached value while a single background refresh continues and updates the cache when the worker becomes available. The in-flight guard and lifecycle context keep this from spawning unbounded refreshes while preserving live VRAM refresh behavior for CUDA.

Fixes #16081
This commit is contained in:
Daniel Hiltgen 2026-05-11 16:03:38 -07:00 committed by GitHub
parent d819ef0f97
commit 206b049508
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 375 additions and 10 deletions

View file

@ -55,6 +55,8 @@ func Execute(args []string) error {
mlx.Sweep() mlx.Sweep()
mlx.ClearCache() mlx.ClearCache()
}) })
runnerCtx, cancelRunner := context.WithCancel(context.Background())
defer cancelRunner()
runner := Runner{ runner := Runner{
Requests: make(chan Request), Requests: make(chan Request),
@ -67,22 +69,30 @@ func Execute(args []string) error {
return err return err
} }
readMemory := func() (uint64, error) {
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
}
initialMemory, err := mlxthread.Call(context.Background(), worker, readMemory)
if err != nil {
return err
}
memoryCache := newStatusMemoryCache(
runnerCtx,
initialMemory,
time.Now(),
statusMemoryRefreshWait,
func() (uint64, error) {
return mlxthread.Call(runnerCtx, worker, readMemory)
},
)
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
memory, err := mlxthread.Call(r.Context(), worker, func() (uint64, error) {
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
})
if err != nil {
slog.Error("Failed to read MLX memory status", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(statusResponse{ if err := json.NewEncoder(w).Encode(statusResponse{
Status: 0, Status: 0,
Progress: 100, Progress: 100,
ContextLength: runner.contextLength, ContextLength: runner.contextLength,
Memory: memory, Memory: memoryCache.Memory(),
}); err != nil { }); err != nil {
slog.Error("Failed to encode response", "error", err) slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)

View file

@ -0,0 +1,109 @@
package mlxrunner
import (
"context"
"log/slog"
"sync"
"time"
)
// statusMemoryRefreshWait bounds how long a status request waits for an
// in-flight memory refresh before falling back to the cached sample.
const statusMemoryRefreshWait = 50 * time.Millisecond

// statusMemoryRefreshFunc samples current memory usage; it may block while
// the serialized MLX worker is busy with model work.
type statusMemoryRefreshFunc func() (uint64, error)
// statusMemoryCache keeps health checks from depending synchronously on the
// serialized MLX worker while still refreshing memory telemetry opportunistically.
type statusMemoryCache struct {
	done    <-chan struct{}         // closed when the runner's lifecycle context is cancelled
	wait    time.Duration           // how long Memory waits for a refresh; <= 0 blocks until it completes
	refresh statusMemoryRefreshFunc // samples memory, typically via the MLX worker

	mu          sync.Mutex    // guards the fields below
	memory      uint64        // last successfully sampled memory value
	refreshedAt time.Time     // when memory was last refreshed; zero means no refresh has succeeded yet
	inFlight    chan struct{} // non-nil while a background refresh runs; closed when it finishes
}
// newStatusMemoryCache builds a cache seeded with an already-sampled memory
// value so the first status request never has to block on the worker. The
// context's Done channel bounds how long callers and background refreshes
// may linger once the runner shuts down.
func newStatusMemoryCache(ctx context.Context, memory uint64, refreshedAt time.Time, wait time.Duration, refresh statusMemoryRefreshFunc) *statusMemoryCache {
	cache := &statusMemoryCache{
		done:        ctx.Done(),
		wait:        wait,
		refresh:     refresh,
		memory:      memory,
		refreshedAt: refreshedAt,
	}
	return cache
}
// Memory returns the most recent memory sample, giving an in-progress
// refresh up to c.wait to land first. A non-positive wait blocks until the
// refresh completes; otherwise a stale (or never-refreshed) cached value is
// returned once the wait elapses or the lifecycle context is done.
func (c *statusMemoryCache) Memory() uint64 {
	refreshed := c.startRefresh()

	// Blocking mode: no deadline, just wait the refresh out.
	if c.wait <= 0 {
		<-refreshed
		m, _ := c.snapshot()
		return m
	}

	deadline := time.NewTimer(c.wait)
	defer deadline.Stop()

	select {
	case <-refreshed:
	case <-c.done:
	case <-deadline.C:
		m, at := c.snapshot()
		if at.IsZero() {
			slog.Debug("using cached MLX memory status before first refresh")
		} else {
			slog.Debug("using cached MLX memory status", "stale", time.Since(at))
		}
		return m
	}

	m, _ := c.snapshot()
	return m
}
// startRefresh kicks off a background memory refresh unless one is already
// running, and returns the channel that closes when that refresh finishes.
// At most one refresh is ever in flight, so a flood of status requests
// cannot pile work onto the MLX worker.
func (c *statusMemoryCache) startRefresh() chan struct{} {
	c.mu.Lock()
	if pending := c.inFlight; pending != nil {
		// Piggyback on the refresh that is already running.
		c.mu.Unlock()
		return pending
	}
	done := make(chan struct{})
	c.inFlight = done
	refreshFn, lifecycle := c.refresh, c.done
	c.mu.Unlock()

	go func() {
		memory, err := refreshFn()
		sampledAt := time.Now()

		c.mu.Lock()
		c.inFlight = nil
		if err == nil {
			c.memory = memory
			c.refreshedAt = sampledAt
		}
		c.mu.Unlock()
		close(done)

		if err != nil {
			select {
			case <-lifecycle:
				// Shutdown cancels the refresh; the failure is expected noise.
			default:
				slog.Debug("failed to refresh MLX memory status", "error", err)
			}
		}
	}()
	return done
}
// snapshot reads the cached memory value and its refresh timestamp under the
// lock.
func (c *statusMemoryCache) snapshot() (uint64, time.Time) {
	c.mu.Lock()
	memory, at := c.memory, c.refreshedAt
	c.mu.Unlock()
	return memory, at
}

View file

@ -0,0 +1,246 @@
package mlxrunner
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
// A refresh that completes well inside the wait budget should be reflected
// in the returned value rather than the stale cached one.
func TestStatusMemoryCacheWaitsForFastRefresh(t *testing.T) {
	var refreshCount atomic.Int32
	cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
		refreshCount.Add(1)
		return 42, nil
	})

	if memory := cache.Memory(); memory != 42 {
		t.Fatalf("got memory %d, want 42", memory)
	}
	if n := refreshCount.Load(); n != 1 {
		t.Fatalf("refresh calls = %d, want 1", n)
	}
}
// A non-positive wait means Memory blocks until the refresh lands.
func TestStatusMemoryCacheSupportsBlockingWait(t *testing.T) {
	cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), 0, func() (uint64, error) {
		return 42, nil
	})

	memory := cache.Memory()
	if memory != 42 {
		t.Fatalf("got memory %d, want 42", memory)
	}
}
// A slow refresh must not delay Memory past the wait budget; the cached
// value comes back immediately and the refresh updates the cache once it
// eventually completes.
func TestStatusMemoryCacheReturnsCachedValueAndRefreshesLater(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	started := make(chan struct{})
	release := make(chan struct{})
	var refreshCount atomic.Int32
	refresh := func() (uint64, error) {
		if refreshCount.Add(1) == 1 {
			close(started)
		}
		select {
		case <-release:
			return 42, nil
		case <-ctx.Done():
			return 0, ctx.Err()
		}
	}
	cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Millisecond, refresh)

	begin := time.Now()
	if memory := cache.Memory(); memory != 7 {
		t.Fatalf("got memory %d, want cached value 7", memory)
	}
	if elapsed := time.Since(begin); elapsed > time.Second {
		t.Fatalf("cached memory lookup took too long: %s", elapsed)
	}

	waitForRefreshStart(t, started)
	close(release)
	waitForCachedMemory(t, cache, 42)

	if n := refreshCount.Load(); n != 1 {
		t.Fatalf("refresh calls = %d, want 1", n)
	}
}
// Even with a zero refreshedAt (no refresh has ever succeeded), Memory
// serves the seeded value instead of blocking on the first slow refresh.
func TestStatusMemoryCacheReturnsCachedValueBeforeFirstRefresh(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	started := make(chan struct{})
	release := make(chan struct{})
	cache := newStatusMemoryCache(ctx, 7, time.Time{}, time.Millisecond, func() (uint64, error) {
		close(started)
		select {
		case <-release:
			return 42, nil
		case <-ctx.Done():
			return 0, ctx.Err()
		}
	})

	if memory := cache.Memory(); memory != 7 {
		t.Fatalf("got memory %d, want cached value 7", memory)
	}

	waitForRefreshStart(t, started)
	close(release)
	waitForCachedMemory(t, cache, 42)
}
// A failed refresh leaves the previously cached value in place.
func TestStatusMemoryCacheKeepsCachedValueWhenRefreshFails(t *testing.T) {
	var refreshCount atomic.Int32
	failing := func() (uint64, error) {
		refreshCount.Add(1)
		return 0, errors.New("refresh failed")
	}
	cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, failing)

	if memory := cache.Memory(); memory != 7 {
		t.Fatalf("got memory %d, want cached value 7", memory)
	}
	if n := refreshCount.Load(); n != 1 {
		t.Fatalf("refresh calls = %d, want 1", n)
	}
}
// Once the lifecycle context is cancelled, Memory must not wait out the full
// budget; it bails via the done channel and serves the cached value.
func TestStatusMemoryCacheReturnsCachedValueWhenContextDone(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	started := make(chan struct{})
	release := make(chan struct{})
	cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
		close(started)
		<-release
		return 0, ctx.Err()
	})

	// Cancel first so Memory has to take the lifecycle path.
	cancel()
	if memory := cache.Memory(); memory != 7 {
		t.Fatalf("got memory %d, want cached value 7", memory)
	}

	waitForRefreshStart(t, started)
	close(release)
	waitForInflightRefresh(t, cache)
}
// A failed refresh clears the in-flight guard so the next Memory call can
// try again and succeed.
func TestStatusMemoryCacheAllowsRefreshAfterFailure(t *testing.T) {
	var refreshCount atomic.Int32
	cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
		if refreshCount.Add(1) == 1 {
			return 0, errors.New("refresh failed")
		}
		return 42, nil
	})

	if memory := cache.Memory(); memory != 7 {
		t.Fatalf("got memory %d, want cached value 7", memory)
	}
	if memory := cache.Memory(); memory != 42 {
		t.Fatalf("got memory %d after retry, want 42", memory)
	}
	if n := refreshCount.Load(); n != 2 {
		t.Fatalf("refresh calls = %d, want 2", n)
	}
}
// Concurrent Memory calls while a refresh is blocked must all fall back to
// the cached value and must coalesce onto a single in-flight refresh.
func TestStatusMemoryCacheAllowsOneInflightRefresh(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	started := make(chan struct{})
	release := make(chan struct{})
	var refreshCount atomic.Int32
	cache := newStatusMemoryCache(ctx, 11, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
		if refreshCount.Add(1) == 1 {
			close(started)
		}
		select {
		case <-release:
			return 99, nil
		case <-ctx.Done():
			return 0, ctx.Err()
		}
	})

	const goroutines = 8
	failures := make(chan string, goroutines)
	var wg sync.WaitGroup
	wg.Add(goroutines)
	for range goroutines {
		go func() {
			defer wg.Done()
			if memory := cache.Memory(); memory != 11 {
				failures <- "got non-cached memory value"
			}
		}()
	}
	wg.Wait()
	close(failures)
	for msg := range failures {
		t.Fatal(msg)
	}

	waitForRefreshStart(t, started)
	if n := refreshCount.Load(); n != 1 {
		t.Fatalf("refresh calls = %d, want 1", n)
	}
	close(release)
	waitForCachedMemory(t, cache, 99)
}
// waitForRefreshStart blocks until the refresh callback signals it has begun,
// failing the test after one second.
func waitForRefreshStart(t *testing.T, started <-chan struct{}) {
	t.Helper()
	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()
	select {
	case <-started:
	case <-timeout.C:
		t.Fatal("timeout waiting for refresh to start")
	}
}
// waitForCachedMemory polls the cache until it holds want, failing the test
// after one second.
func waitForCachedMemory(t *testing.T, cache *statusMemoryCache, want uint64) {
	t.Helper()
	deadline := time.After(time.Second)
	for {
		memory, _ := cache.snapshot()
		if memory == want {
			return
		}
		select {
		case <-deadline:
			t.Fatalf("cached memory = %d, want %d", memory, want)
		case <-time.After(time.Millisecond):
			// Poll again.
		}
	}
}
// waitForInflightRefresh polls until no background refresh is in flight,
// failing the test after one second.
func waitForInflightRefresh(t *testing.T, cache *statusMemoryCache) {
	t.Helper()
	deadline := time.After(time.Second)
	for {
		cache.mu.Lock()
		pending := cache.inFlight
		cache.mu.Unlock()
		if pending == nil {
			return
		}
		select {
		case <-deadline:
			t.Fatal("timeout waiting for refresh to finish")
		case <-time.After(time.Millisecond):
			// Poll again.
		}
	}
}