mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 06:21:28 +00:00
mlx: update the imagegen runner for mlx thread affinity (#16096)
This commit is contained in:
parent
3d5a011a2e
commit
d819ef0f97
5 changed files with 272 additions and 11 deletions
|
|
@ -2856,8 +2856,12 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
|
|||
}); err != nil {
|
||||
// Only send JSON error if streaming hasn't started yet
|
||||
// (once streaming starts, headers are committed and we can't change status code)
|
||||
if !streamStarted {
|
||||
if !isStreaming || !streamStarted {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
} else {
|
||||
data, _ := json.Marshal(gin.H{"error": err.Error()})
|
||||
c.Writer.Write(append(data, '\n'))
|
||||
c.Writer.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
|
@ -2688,3 +2689,120 @@ func TestImageGenerateStreamFalse(t *testing.T) {
|
|||
t.Errorf("expected done=true")
|
||||
}
|
||||
}
|
||||
|
||||
func newImageGenerateTestServer(t *testing.T, mock *mockRunner) Server {
|
||||
t.Helper()
|
||||
|
||||
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
n := model.ParseName("test-image")
|
||||
cfg := model.ConfigV2{Capabilities: []string{"image"}}
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loadedModel, err := GetModel("test-image")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: map[string]*runnerRef{
|
||||
schedulerModelKey(loadedModel): {
|
||||
llama: mock,
|
||||
Options: &opts,
|
||||
model: loadedModel,
|
||||
isImagegen: true,
|
||||
numParallel: 1,
|
||||
},
|
||||
},
|
||||
newServerFn: newMockServer(mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getSystemInfoFn: getSystemInfoFn,
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
return s
|
||||
}
|
||||
|
||||
func TestImageGenerateStreamFalseErrorAfterProgress(t *testing.T) {
|
||||
mock := mockRunner{}
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
|
||||
return errors.New("runner died")
|
||||
}
|
||||
s := newImageGenerateTestServer(t, &mock)
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
Prompt: "test prompt",
|
||||
Stream: &streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected status 500, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
if !strings.Contains(w.Body.String(), "runner died") {
|
||||
t.Fatalf("expected runner error in body, got %q", w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageGenerateStreamingErrorAfterProgress(t *testing.T) {
|
||||
mock := mockRunner{}
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
|
||||
return errors.New("runner died")
|
||||
}
|
||||
s := newImageGenerateTestServer(t, &mock)
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-image",
|
||||
Prompt: "test prompt",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200 after streaming started, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
lines := strings.Split(strings.TrimSpace(w.Body.String()), "\n")
|
||||
if len(lines) != 2 {
|
||||
t.Fatalf("expected progress and error lines, got %d:\n%s", len(lines), w.Body.String())
|
||||
}
|
||||
|
||||
var progress api.GenerateResponse
|
||||
if err := json.Unmarshal([]byte(lines[0]), &progress); err != nil {
|
||||
t.Fatalf("failed to parse progress response: %v", err)
|
||||
}
|
||||
if progress.Completed != 1 || progress.Total != 3 || progress.Done {
|
||||
t.Fatalf("progress response = %+v", progress)
|
||||
}
|
||||
|
||||
var errorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(lines[1]), &errorResponse); err != nil {
|
||||
t.Fatalf("failed to parse error response: %v", err)
|
||||
}
|
||||
if errorResponse.Error != "runner died" {
|
||||
t.Fatalf("error = %q, want runner died", errorResponse.Error)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/internal/mlxthread"
|
||||
)
|
||||
|
||||
// Execute is the entry point for the unified MLX runner subprocess.
|
||||
|
|
@ -45,17 +46,30 @@ func Execute(args []string) error {
|
|||
return fmt.Errorf("imagegen runner only supports image generation models")
|
||||
}
|
||||
|
||||
// Initialize MLX only for image generation mode.
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
worker, err := mlxthread.Start("imagegen", func() error {
|
||||
if err := mlx.InitMLX(); err != nil {
|
||||
slog.Error("unable to initialize MLX", "error", err)
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Info("MLX library initialized")
|
||||
|
||||
// Create and start server
|
||||
server, err := newServer(*modelName, *port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create server: %w", err)
|
||||
var server *server
|
||||
if err := worker.Do(context.Background(), func() error {
|
||||
var err error
|
||||
server, err = newServer(*modelName, *port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create server: %w", err)
|
||||
}
|
||||
server.mlxThread = worker
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up HTTP handlers
|
||||
|
|
@ -77,7 +91,17 @@ func Execute(args []string) error {
|
|||
slog.Info("shutting down mlx runner")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
httpServer.Shutdown(ctx)
|
||||
if err := httpServer.Shutdown(ctx); err != nil {
|
||||
slog.Warn("graceful shutdown timed out", "error", err)
|
||||
if err := httpServer.Close(); err != nil {
|
||||
slog.Warn("failed to close http server", "error", err)
|
||||
}
|
||||
}
|
||||
if err := worker.Stop(ctx, func() {
|
||||
mlx.ClearCache()
|
||||
}); err != nil {
|
||||
slog.Warn("failed to stop mlx worker", "error", err)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
|
@ -110,6 +134,7 @@ func detectModelMode(modelName string) ModelMode {
|
|||
type server struct {
|
||||
modelName string
|
||||
port int
|
||||
mlxThread *mlxthread.Thread
|
||||
|
||||
// Image generation model.
|
||||
imageModel ImageModel
|
||||
|
|
@ -147,5 +172,10 @@ func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
s.handleImageCompletion(w, r, req)
|
||||
if err := s.mlxThread.Do(r.Context(), func() error {
|
||||
s.handleImageCompletion(w, r, req)
|
||||
return nil
|
||||
}); err != nil && r.Context().Err() == nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -367,9 +367,16 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
|
|||
// Check if subprocess is still alive
|
||||
if s.HasExited() {
|
||||
slog.Error("mlx subprocess has exited unexpectedly")
|
||||
if errMsg := s.getLastErr(); errMsg != "" {
|
||||
return fmt.Errorf("mlx runner closed response before completion: %s", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
return scanErr
|
||||
if scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
|
||||
return errors.New("mlx runner closed response before completion")
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
|
|
|
|||
102
x/imagegen/server_test.go
Normal file
102
x/imagegen/server_test.go
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
package imagegen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
// roundTripFunc adapts an ordinary function to the http.RoundTripper
// interface so a test can supply the response for each request itself.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by calling f with req and returning
// its response and error unchanged.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
|
||||
|
||||
func newCompletionTestServer(handler func(*http.Request) string) *Server {
|
||||
return &Server{
|
||||
port: 11434,
|
||||
done: make(chan error, 1),
|
||||
client: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
body := handler(req)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionReturnsImageData(t *testing.T) {
|
||||
s := newCompletionTestServer(func(r *http.Request) string {
|
||||
if r.URL.Path != "/completion" {
|
||||
t.Fatalf("path = %q, want /completion", r.URL.Path)
|
||||
}
|
||||
|
||||
var req Request
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if req.Prompt != "test prompt" || req.Width != 512 || req.Height != 256 || req.Steps != 7 || req.Seed != 42 {
|
||||
t.Fatalf("unexpected request: %+v", req)
|
||||
}
|
||||
if len(req.Images) != 1 || string(req.Images[0]) != "input-image" {
|
||||
t.Fatalf("images = %q, want input-image", req.Images)
|
||||
}
|
||||
|
||||
return `{"step":1,"total":2}` + "\n" +
|
||||
`{"done":true,"image":"base64png"}` + "\n"
|
||||
})
|
||||
|
||||
var responses []llm.CompletionResponse
|
||||
err := s.Completion(context.Background(), llm.CompletionRequest{
|
||||
Prompt: "test prompt",
|
||||
Width: 512,
|
||||
Height: 256,
|
||||
Steps: 7,
|
||||
Seed: 42,
|
||||
Images: []llm.ImageData{{Data: []byte("input-image")}},
|
||||
}, func(resp llm.CompletionResponse) {
|
||||
responses = append(responses, resp)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(responses) != 2 {
|
||||
t.Fatalf("responses = %d, want 2", len(responses))
|
||||
}
|
||||
if responses[0].Step != 1 || responses[0].TotalSteps != 2 || responses[0].Done {
|
||||
t.Fatalf("progress response = %+v", responses[0])
|
||||
}
|
||||
if !responses[1].Done || responses[1].Image != "base64png" {
|
||||
t.Fatalf("final response = %+v", responses[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionEOFBeforeDoneReturnsError(t *testing.T) {
|
||||
s := newCompletionTestServer(func(r *http.Request) string {
|
||||
return `{"step":1,"total":2}` + "\n"
|
||||
})
|
||||
|
||||
var responses []llm.CompletionResponse
|
||||
err := s.Completion(context.Background(), llm.CompletionRequest{Prompt: "test prompt"}, func(resp llm.CompletionResponse) {
|
||||
responses = append(responses, resp)
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "closed response before completion") {
|
||||
t.Fatalf("error = %v", err)
|
||||
}
|
||||
if len(responses) != 1 || responses[0].Done {
|
||||
t.Fatalf("responses = %+v, want one non-done progress response", responses)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue