ollama/api/client_test.go
2026-06-12 16:28:58 -07:00

448 lines
12 KiB
Go

package api
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
value string
expect string
err error
}
testCases := map[string]*testCase{
"empty": {value: "", expect: "http://127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "http://1.2.3.4:11434"},
"only port": {value: ":1234", expect: "http://:1234"},
"address and port": {value: "1.2.3.4:1234", expect: "http://1.2.3.4:1234"},
"scheme http and address": {value: "http://1.2.3.4", expect: "http://1.2.3.4:80"},
"scheme https and address": {value: "https://1.2.3.4", expect: "https://1.2.3.4:443"},
"scheme, address, and port": {value: "https://1.2.3.4:1234", expect: "https://1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "http://example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "http://example.com:1234"},
"scheme http and hostname": {value: "http://example.com", expect: "http://example.com:80"},
"scheme https and hostname": {value: "https://example.com", expect: "https://example.com:443"},
"scheme, hostname, and port": {value: "https://example.com:1234", expect: "https://example.com:1234"},
"trailing slash": {value: "example.com/", expect: "http://example.com:11434"},
"trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"},
}
for k, v := range testCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
client, err := ClientFromEnvironment()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if client.base.String() != v.expect {
t.Fatalf("expected %s, got %s", v.expect, client.base.String())
}
})
}
}
// testError represents an internal error type with status code and message
// this is used since the error response from the server is not a standard error struct
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
return e.message
}
func TestClientStream(t *testing.T) {
testCases := []struct {
name string
responses []any
wantErr string
}{
{
name: "immediate error response",
responses: []any{
testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
},
wantErr: "test error message",
},
{
name: "error after successful chunks, ok response",
responses: []any{
ChatResponse{Message: Message{Content: "partial response 1"}},
ChatResponse{Message: Message{Content: "partial response 2"}},
testError{
message: "mid-stream error",
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream error",
},
{
name: "http status error takes precedence over general error",
responses: []any{
testError{
message: "custom error message",
statusCode: http.StatusInternalServerError,
},
},
wantErr: "500",
},
{
name: "successful stream completion",
responses: []any{
ChatResponse{Message: Message{Content: "chunk 1"}},
ChatResponse{Message: Message{Content: "chunk 2"}},
ChatResponse{
Message: Message{Content: "final chunk"},
Done: true,
DoneReason: "stop",
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("expected http.Flusher")
}
w.Header().Set("Content-Type", "application/x-ndjson")
for _, resp := range tc.responses {
if errResp, ok := resp.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
flusher.Flush()
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
}
receivedChunks = append(receivedChunks, resp)
return nil
})
if tc.wantErr != "" {
if err == nil {
t.Fatal("expected error but got nil")
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
func TestClientStreamReportsReadErrors(t *testing.T) {
client := NewClient(
&url.URL{Scheme: "http", Host: "example.com"},
&http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
body := failingReader{
data: []byte(`{"message":{"content":"partial"}}` + "\n"),
err: io.ErrUnexpectedEOF,
}
return &http.Response{
StatusCode: http.StatusOK,
Status: "200 OK",
Body: io.NopCloser(&body),
Header: make(http.Header),
}, nil
})},
)
err := client.stream(t.Context(), http.MethodPost, "/api/chat", nil, func([]byte) error {
return nil
})
if err == nil {
t.Fatal("expected stream read error")
}
if !strings.Contains(err.Error(), io.ErrUnexpectedEOF.Error()) {
t.Fatalf("expected unexpected EOF, got %v", err)
}
}
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
}{
{
name: "immediate error response",
response: testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
},
{
name: "server error response",
response: testError{
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "successful response",
response: struct {
ID string `json:"id"`
Success bool `json:"success"`
}{
ID: "msg_123",
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var resp struct {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}
if err != nil {
t.Fatalf("got error %q, want nil", err)
}
if expectedResp, ok := tc.response.(struct {
ID string `json:"id"`
Success bool `json:"success"`
}); ok {
if resp.ID != expectedResp.ID {
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
}
if resp.Success != expectedResp.Success {
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
}
}
})
}
}
func TestClientWebSearchExperimentalUsesLocalRoute(t *testing.T) {
var gotPath string
var gotMethod string
var gotRequest WebSearchRequest
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotMethod = r.Method
if err := json.NewDecoder(r.Body).Decode(&gotRequest); err != nil {
t.Fatal(err)
}
if err := json.NewEncoder(w).Encode(WebSearchResponse{
Results: []WebSearchResult{{Title: "Ollama", URL: "https://ollama.com", Content: "models"}},
}); err != nil {
t.Fatal(err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
resp, err := client.WebSearchExperimental(t.Context(), &WebSearchRequest{Query: "ollama", MaxResults: 3})
if err != nil {
t.Fatal(err)
}
if gotMethod != http.MethodPost {
t.Fatalf("method = %q, want POST", gotMethod)
}
if gotPath != "/api/experimental/web_search" {
t.Fatalf("path = %q, want /api/experimental/web_search", gotPath)
}
if gotRequest.Query != "ollama" || gotRequest.MaxResults != 3 {
t.Fatalf("request = %#v", gotRequest)
}
if len(resp.Results) != 1 || resp.Results[0].Title != "Ollama" {
t.Fatalf("response = %#v", resp)
}
}
func TestClientWebFetchExperimentalUsesLocalRoute(t *testing.T) {
var gotPath string
var gotMethod string
var gotRequest WebFetchRequest
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotMethod = r.Method
if err := json.NewDecoder(r.Body).Decode(&gotRequest); err != nil {
t.Fatal(err)
}
if err := json.NewEncoder(w).Encode(WebFetchResponse{
Title: "Ollama",
Content: "models",
Links: []string{"https://ollama.com/library"},
}); err != nil {
t.Fatal(err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
resp, err := client.WebFetchExperimental(t.Context(), &WebFetchRequest{URL: "https://ollama.com"})
if err != nil {
t.Fatal(err)
}
if gotMethod != http.MethodPost {
t.Fatalf("method = %q, want POST", gotMethod)
}
if gotPath != "/api/experimental/web_fetch" {
t.Fatalf("path = %q, want /api/experimental/web_fetch", gotPath)
}
if gotRequest.URL != "https://ollama.com" {
t.Fatalf("request = %#v", gotRequest)
}
if resp.Title != "Ollama" || resp.Content != "models" {
t.Fatalf("response = %#v", resp)
}
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
type failingReader struct {
data []byte
err error
}
func (r *failingReader) Read(p []byte) (int, error) {
if len(r.data) > 0 {
n := copy(p, r.data)
r.data = r.data[n:]
return n, nil
}
return 0, r.err
}