mlx: update the imagegen runner for mlx thread affinity (#16096)

This commit is contained in:
Patrick Devine 2026-05-11 13:05:06 -07:00 committed by GitHub
parent 3d5a011a2e
commit d819ef0f97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 272 additions and 11 deletions

View file

@@ -2856,8 +2856,12 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}); err != nil {
// Only send JSON error if streaming hasn't started yet
// (once streaming starts, headers are committed and we can't change status code)
if !streamStarted {
if !isStreaming || !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
} else {
data, _ := json.Marshal(gin.H{"error": err.Error()})
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
}
return
}

View file

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
@@ -2688,3 +2689,120 @@ func TestImageGenerateStreamFalse(t *testing.T) {
t.Errorf("expected done=true")
}
}
// newImageGenerateTestServer builds a Server whose scheduler already holds a
// loaded image-generation runner backed by the supplied mock, letting handler
// tests exercise the imagegen path without loading a real model.
func newImageGenerateTestServer(t *testing.T, mock *mockRunner) Server {
	t.Helper()
	t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
	gin.SetMode(gin.TestMode)

	// Point the model store at a throwaway directory and register a minimal
	// manifest for a model advertising the "image" capability.
	modelDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelDir)

	name := model.ParseName("test-image")
	conf := model.ConfigV2{Capabilities: []string{"image"}}

	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(&conf); err != nil {
		t.Fatal(err)
	}
	layer, err := manifest.NewLayer(&buf, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatal(err)
	}
	if err := manifest.WriteManifest(name, layer, nil); err != nil {
		t.Fatal(err)
	}

	mdl, err := GetModel("test-image")
	if err != nil {
		t.Fatal(err)
	}

	// Pre-populate the scheduler with a single already-loaded imagegen runner
	// so requests bypass the load path and hit the mock directly.
	options := api.DefaultOptions()
	srv := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded: map[string]*runnerRef{
				schedulerModelKey(mdl): {
					llama:       mock,
					Options:     &options,
					model:       mdl,
					isImagegen:  true,
					numParallel: 1,
				},
			},
			newServerFn:     newMockServer(mock),
			getGpuFn:        getGpuFn,
			getSystemInfoFn: getSystemInfoFn,
		},
	}
	go srv.sched.Run(t.Context())
	return srv
}
// TestImageGenerateStreamFalseErrorAfterProgress verifies that when streaming
// is disabled, a runner failure that occurs after a progress callback still
// surfaces as a 500 JSON error (nothing has been written to the client yet).
func TestImageGenerateStreamFalseErrorAfterProgress(t *testing.T) {
	runner := mockRunner{}
	runner.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
		// Emit one progress update, then fail mid-generation.
		fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
		return errors.New("runner died")
	}
	srv := newImageGenerateTestServer(t, &runner)

	noStream := false
	resp := createRequest(t, srv.GenerateHandler, api.GenerateRequest{
		Model:  "test-image",
		Prompt: "test prompt",
		Stream: &noStream,
	})

	if resp.Code != http.StatusInternalServerError {
		t.Fatalf("expected status 500, got %d: %s", resp.Code, resp.Body.String())
	}
	if body := resp.Body.String(); !strings.Contains(body, "runner died") {
		t.Fatalf("expected runner error in body, got %q", body)
	}
}
// TestImageGenerateStreamingErrorAfterProgress verifies the streaming path:
// once a progress line has been written the status stays 200, and the runner's
// error is delivered as a trailing NDJSON error line rather than a new status.
func TestImageGenerateStreamingErrorAfterProgress(t *testing.T) {
	runner := mockRunner{}
	runner.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
		// One progress update, then a mid-stream failure.
		fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
		return errors.New("runner died")
	}
	srv := newImageGenerateTestServer(t, &runner)

	resp := createRequest(t, srv.GenerateHandler, api.GenerateRequest{
		Model:  "test-image",
		Prompt: "test prompt",
	})
	if resp.Code != http.StatusOK {
		t.Fatalf("expected status 200 after streaming started, got %d: %s", resp.Code, resp.Body.String())
	}

	// Body should be exactly two NDJSON lines: the progress record, then the error.
	lines := strings.Split(strings.TrimSpace(resp.Body.String()), "\n")
	if len(lines) != 2 {
		t.Fatalf("expected progress and error lines, got %d:\n%s", len(lines), resp.Body.String())
	}

	var progress api.GenerateResponse
	if err := json.Unmarshal([]byte(lines[0]), &progress); err != nil {
		t.Fatalf("failed to parse progress response: %v", err)
	}
	if progress.Completed != 1 || progress.Total != 3 || progress.Done {
		t.Fatalf("progress response = %+v", progress)
	}

	var errLine struct {
		Error string `json:"error"`
	}
	if err := json.Unmarshal([]byte(lines[1]), &errLine); err != nil {
		t.Fatalf("failed to parse error response: %v", err)
	}
	if errLine.Error != "runner died" {
		t.Fatalf("error = %q, want runner died", errLine.Error)
	}
}

View file

@@ -15,6 +15,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/internal/mlxthread"
)
// Execute is the entry point for the unified MLX runner subprocess.
@@ -45,17 +46,30 @@ func Execute(args []string) error {
return fmt.Errorf("imagegen runner only supports image generation models")
}
// Initialize MLX only for image generation mode.
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
worker, err := mlxthread.Start("imagegen", func() error {
if err := mlx.InitMLX(); err != nil {
slog.Error("unable to initialize MLX", "error", err)
return err
}
slog.Info("MLX library initialized")
return nil
})
if err != nil {
return err
}
slog.Info("MLX library initialized")
// Create and start server
server, err := newServer(*modelName, *port)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
var server *server
if err := worker.Do(context.Background(), func() error {
var err error
server, err = newServer(*modelName, *port)
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}
server.mlxThread = worker
return nil
}); err != nil {
return err
}
// Set up HTTP handlers
@@ -77,7 +91,17 @@ func Execute(args []string) error {
slog.Info("shutting down mlx runner")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
httpServer.Shutdown(ctx)
if err := httpServer.Shutdown(ctx); err != nil {
slog.Warn("graceful shutdown timed out", "error", err)
if err := httpServer.Close(); err != nil {
slog.Warn("failed to close http server", "error", err)
}
}
if err := worker.Stop(ctx, func() {
mlx.ClearCache()
}); err != nil {
slog.Warn("failed to stop mlx worker", "error", err)
}
close(done)
}()
@@ -110,6 +134,7 @@ func detectModelMode(modelName string) ModelMode {
type server struct {
modelName string
port int
mlxThread *mlxthread.Thread
// Image generation model.
imageModel ImageModel
@@ -147,5 +172,10 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
return
}
s.handleImageCompletion(w, r, req)
if err := s.mlxThread.Do(r.Context(), func() error {
s.handleImageCompletion(w, r, req)
return nil
}); err != nil && r.Context().Err() == nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

View file

@@ -367,9 +367,16 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
// Check if subprocess is still alive
if s.HasExited() {
slog.Error("mlx subprocess has exited unexpectedly")
if errMsg := s.getLastErr(); errMsg != "" {
return fmt.Errorf("mlx runner closed response before completion: %s", errMsg)
}
}
return scanErr
if scanErr != nil {
return scanErr
}
return errors.New("mlx runner closed response before completion")
}
// Close terminates the subprocess.

102
x/imagegen/server_test.go Normal file
View file

@@ -0,0 +1,102 @@
package imagegen
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"github.com/ollama/ollama/llm"
)
// roundTripFunc adapts a plain function into an http.RoundTripper, letting
// tests stub out the HTTP transport without opening a real connection.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
// newCompletionTestServer returns a Server whose HTTP client is backed by an
// in-memory transport: every request is answered with status 200 and a body
// produced by the supplied handler, so no real subprocess or socket is needed.
func newCompletionTestServer(handler func(*http.Request) string) *Server {
	transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
		resp := &http.Response{
			StatusCode: http.StatusOK,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(handler(req))),
			Request:    req,
		}
		return resp, nil
	})
	return &Server{
		port:   11434,
		done:   make(chan error, 1),
		client: &http.Client{Transport: transport},
	}
}
// TestCompletionReturnsImageData covers the happy path: the runner streams a
// progress line followed by a done line carrying the generated image, and
// Completion relays both, in order, to the callback.
func TestCompletionReturnsImageData(t *testing.T) {
	srv := newCompletionTestServer(func(r *http.Request) string {
		// Validate the outgoing request before replying.
		if r.URL.Path != "/completion" {
			t.Fatalf("path = %q, want /completion", r.URL.Path)
		}
		var req Request
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			t.Fatal(err)
		}
		if req.Prompt != "test prompt" || req.Width != 512 || req.Height != 256 || req.Steps != 7 || req.Seed != 42 {
			t.Fatalf("unexpected request: %+v", req)
		}
		if len(req.Images) != 1 || string(req.Images[0]) != "input-image" {
			t.Fatalf("images = %q, want input-image", req.Images)
		}
		// Reply with one progress record and one final image record.
		return `{"step":1,"total":2}` + "\n" +
			`{"done":true,"image":"base64png"}` + "\n"
	})

	var got []llm.CompletionResponse
	err := srv.Completion(context.Background(), llm.CompletionRequest{
		Prompt: "test prompt",
		Width:  512,
		Height: 256,
		Steps:  7,
		Seed:   42,
		Images: []llm.ImageData{{Data: []byte("input-image")}},
	}, func(cr llm.CompletionResponse) {
		got = append(got, cr)
	})
	if err != nil {
		t.Fatal(err)
	}

	if len(got) != 2 {
		t.Fatalf("responses = %d, want 2", len(got))
	}
	if first := got[0]; first.Step != 1 || first.TotalSteps != 2 || first.Done {
		t.Fatalf("progress response = %+v", first)
	}
	if last := got[1]; !last.Done || last.Image != "base64png" {
		t.Fatalf("final response = %+v", last)
	}
}
// TestCompletionEOFBeforeDoneReturnsError ensures that when the runner's
// stream ends before a done record arrives, Completion reports an error
// instead of silently returning a truncated result.
func TestCompletionEOFBeforeDoneReturnsError(t *testing.T) {
	// Reply with a single progress record and then close the stream —
	// no done record is ever sent.
	srv := newCompletionTestServer(func(*http.Request) string {
		return `{"step":1,"total":2}` + "\n"
	})

	var got []llm.CompletionResponse
	err := srv.Completion(context.Background(), llm.CompletionRequest{Prompt: "test prompt"}, func(cr llm.CompletionResponse) {
		got = append(got, cr)
	})

	if err == nil {
		t.Fatal("expected error")
	}
	if !strings.Contains(err.Error(), "closed response before completion") {
		t.Fatalf("error = %v", err)
	}
	if len(got) != 1 || got[0].Done {
		t.Fatalf("responses = %+v, want one non-done progress response", got)
	}
}