mirror of
https://github.com/ollama/ollama.git
synced 2026-05-13 14:27:00 +00:00
anthropic: Preserve Claude local image-path tool results in renderer-owned prompt formatting (#16047)
This commit is contained in:
parent
421faa0263
commit
6bdb73073b
16 changed files with 670 additions and 102 deletions
|
|
@ -78,6 +78,11 @@ type MessagesRequest struct {
|
||||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||||
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
Thinking *ThinkingConfig `json:"thinking,omitempty"`
|
||||||
Metadata *Metadata `json:"metadata,omitempty"`
|
Metadata *Metadata `json:"metadata,omitempty"`
|
||||||
|
OutputConfig *OutputConfig `json:"output_config,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OutputConfig struct {
|
||||||
|
Effort string `json:"effort,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MessageParam represents a message in the request
|
// MessageParam represents a message in the request
|
||||||
|
|
@ -161,7 +166,7 @@ type WebSearchToolResultError struct {
|
||||||
|
|
||||||
// ImageSource represents the source of an image
|
// ImageSource represents the source of an image
|
||||||
type ImageSource struct {
|
type ImageSource struct {
|
||||||
Type string `json:"type"` // "base64" or "url"
|
Type string `json:"type"` // "base64"
|
||||||
MediaType string `json:"media_type,omitempty"`
|
MediaType string `json:"media_type,omitempty"`
|
||||||
Data string `json:"data,omitempty"`
|
Data string `json:"data,omitempty"`
|
||||||
URL string `json:"url,omitempty"`
|
URL string `json:"url,omitempty"`
|
||||||
|
|
@ -373,9 +378,26 @@ func FromMessagesRequest(r MessagesRequest) (*api.ChatRequest, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var think *api.ThinkValue
|
var think *api.ThinkValue
|
||||||
|
normalizedEffort := ""
|
||||||
|
if r.OutputConfig != nil {
|
||||||
|
normalizedEffort = strings.ToLower(strings.TrimSpace(r.OutputConfig.Effort))
|
||||||
|
if normalizedEffort == "xhigh" {
|
||||||
|
normalizedEffort = "high"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
if r.Thinking != nil && r.Thinking.Type == "enabled" {
|
||||||
think = &api.ThinkValue{Value: true}
|
think = &api.ThinkValue{Value: true}
|
||||||
}
|
}
|
||||||
|
if r.Thinking != nil && r.Thinking.Type == "disabled" {
|
||||||
|
think = &api.ThinkValue{Value: false}
|
||||||
|
}
|
||||||
|
if think == nil && r.OutputConfig != nil {
|
||||||
|
switch normalizedEffort {
|
||||||
|
case "high", "medium", "low", "max":
|
||||||
|
think = &api.ThinkValue{Value: normalizedEffort}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
stream := r.Stream
|
stream := r.Stream
|
||||||
convertedRequest := &api.ChatRequest{
|
convertedRequest := &api.ChatRequest{
|
||||||
|
|
@ -425,17 +447,12 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||||
return nil, errors.New("invalid image source")
|
return nil, errors.New("invalid image source")
|
||||||
}
|
}
|
||||||
|
|
||||||
if block.Source.Type == "base64" {
|
decoded, err := resolveImageSource(block.Source)
|
||||||
decoded, err := base64.StdEncoding.DecodeString(block.Source.Data)
|
if err != nil {
|
||||||
if err != nil {
|
logutil.Trace("anthropic: unsupported image source", "role", role, "source_type", block.Source.Type, "error", err)
|
||||||
logutil.Trace("anthropic: invalid base64 image data", "role", role, "error", err)
|
return nil, err
|
||||||
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
|
||||||
}
|
|
||||||
images = append(images, decoded)
|
|
||||||
} else {
|
|
||||||
logutil.Trace("anthropic: unsupported image source type", "role", role, "source_type", block.Source.Type)
|
|
||||||
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", block.Source.Type)
|
|
||||||
}
|
}
|
||||||
|
images = append(images, decoded)
|
||||||
|
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
toolUseBlocks++
|
toolUseBlocks++
|
||||||
|
|
@ -457,26 +474,16 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||||
|
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
toolResultBlocks++
|
toolResultBlocks++
|
||||||
var resultContent string
|
resultContent, resultImages, err := convertToolResultContent(block.Content)
|
||||||
|
if err != nil {
|
||||||
switch c := block.Content.(type) {
|
logutil.Trace("anthropic: invalid tool_result content", "role", role, "error", err)
|
||||||
case string:
|
return nil, err
|
||||||
resultContent = c
|
|
||||||
case []any:
|
|
||||||
for _, cb := range c {
|
|
||||||
if cbMap, ok := cb.(map[string]any); ok {
|
|
||||||
if cbMap["type"] == "text" {
|
|
||||||
if text, ok := cbMap["text"].(string); ok {
|
|
||||||
resultContent += text
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
toolResults = append(toolResults, api.Message{
|
toolResults = append(toolResults, api.Message{
|
||||||
Role: "tool",
|
Role: "tool",
|
||||||
Content: resultContent,
|
Content: resultContent,
|
||||||
|
Images: resultImages,
|
||||||
ToolCallID: block.ToolUseID,
|
ToolCallID: block.ToolUseID,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -508,6 +515,10 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if role == "user" && len(toolResults) > 0 {
|
||||||
|
messages = append(messages, toolResults...)
|
||||||
|
}
|
||||||
|
|
||||||
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
if textContent.Len() > 0 || len(images) > 0 || len(toolCalls) > 0 || thinking != "" {
|
||||||
m := api.Message{
|
m := api.Message{
|
||||||
Role: role,
|
Role: role,
|
||||||
|
|
@ -519,8 +530,10 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
|
||||||
messages = append(messages, m)
|
messages = append(messages, m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add tool results as separate messages
|
// Add tool results as separate messages.
|
||||||
messages = append(messages, toolResults...)
|
if role != "user" || len(toolResults) == 0 {
|
||||||
|
messages = append(messages, toolResults...)
|
||||||
|
}
|
||||||
logutil.Trace("anthropic: converted block message",
|
logutil.Trace("anthropic: converted block message",
|
||||||
"role", role,
|
"role", role,
|
||||||
"blocks", len(msg.Content),
|
"blocks", len(msg.Content),
|
||||||
|
|
@ -969,6 +982,71 @@ func GenerateMessageID() string {
|
||||||
return generateID("msg")
|
return generateID("msg")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func resolveImageSource(source *ImageSource) (api.ImageData, error) {
|
||||||
|
if source.Type != "base64" {
|
||||||
|
return nil, fmt.Errorf("invalid image source type: %s. Only base64 images are supported.", source.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(source.Data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid base64 image data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolResultContent(content any) (string, []api.ImageData, error) {
|
||||||
|
switch c := content.(type) {
|
||||||
|
case nil:
|
||||||
|
return "", nil, nil
|
||||||
|
case string:
|
||||||
|
return c, nil, nil
|
||||||
|
case []any:
|
||||||
|
var text strings.Builder
|
||||||
|
var images []api.ImageData
|
||||||
|
|
||||||
|
for _, cb := range c {
|
||||||
|
cbMap, ok := cb.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch cbMap["type"] {
|
||||||
|
case "text":
|
||||||
|
if t, ok := cbMap["text"].(string); ok {
|
||||||
|
text.WriteString(t)
|
||||||
|
}
|
||||||
|
case "image":
|
||||||
|
rawSource, ok := cbMap["source"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
return "", nil, errors.New("invalid tool_result image source")
|
||||||
|
}
|
||||||
|
|
||||||
|
var source ImageSource
|
||||||
|
if rawType, ok := rawSource["type"].(string); ok {
|
||||||
|
source.Type = rawType
|
||||||
|
}
|
||||||
|
if rawMediaType, ok := rawSource["media_type"].(string); ok {
|
||||||
|
source.MediaType = rawMediaType
|
||||||
|
}
|
||||||
|
if rawData, ok := rawSource["data"].(string); ok {
|
||||||
|
source.Data = rawData
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := resolveImageSource(&source)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
images = append(images, img)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return text.String(), images, nil
|
||||||
|
default:
|
||||||
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ptr returns a pointer to the given string value
|
// ptr returns a pointer to the given string value
|
||||||
func ptr(s string) *string {
|
func ptr(s string) *string {
|
||||||
return &s
|
return &s
|
||||||
|
|
|
||||||
|
|
@ -271,6 +271,241 @@ func TestFromMessagesRequest_WithToolResult(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_WithToolResultImage(t *testing.T) {
|
||||||
|
imgData, _ := base64.StdEncoding.DecodeString(testImage)
|
||||||
|
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{
|
||||||
|
Type: "tool_result",
|
||||||
|
ToolUseID: "call_img",
|
||||||
|
Content: []any{
|
||||||
|
map[string]any{"type": "text", "text": "Attached image"},
|
||||||
|
map[string]any{
|
||||||
|
"type": "image",
|
||||||
|
"source": map[string]any{
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/png",
|
||||||
|
"data": testImage,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Messages) != 1 {
|
||||||
|
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := result.Messages[0]
|
||||||
|
if msg.Role != "tool" {
|
||||||
|
t.Errorf("expected role 'tool', got %q", msg.Role)
|
||||||
|
}
|
||||||
|
if msg.ToolCallID != "call_img" {
|
||||||
|
t.Errorf("expected tool_call_id 'call_img', got %q", msg.ToolCallID)
|
||||||
|
}
|
||||||
|
if msg.Content != "Attached image" {
|
||||||
|
t.Errorf("unexpected content: %q", msg.Content)
|
||||||
|
}
|
||||||
|
if len(msg.Images) != 1 {
|
||||||
|
t.Fatalf("expected 1 image, got %d", len(msg.Images))
|
||||||
|
}
|
||||||
|
if string(msg.Images[0]) != string(imgData) {
|
||||||
|
t.Error("image data mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_WithToolResultFollowedByUserText(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "test-model",
|
||||||
|
MaxTokens: 1024,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{
|
||||||
|
Type: "tool_use",
|
||||||
|
ID: "call_read",
|
||||||
|
Name: "Read",
|
||||||
|
Input: makeArgs("file_path", "/Users/hoyyeva/Desktop/aaa.png"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []ContentBlock{
|
||||||
|
{
|
||||||
|
Type: "tool_result",
|
||||||
|
ToolUseID: "call_read",
|
||||||
|
Content: "Read image (311.5KB)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: "text",
|
||||||
|
Text: ptr("Please describe it."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Messages) != 3 {
|
||||||
|
t.Fatalf("expected 3 messages, got %d", len(result.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Messages[1].Role != "tool" {
|
||||||
|
t.Fatalf("expected second message to be tool, got %q", result.Messages[1].Role)
|
||||||
|
}
|
||||||
|
if result.Messages[1].ToolCallID != "call_read" {
|
||||||
|
t.Fatalf("expected tool_call_id 'call_read', got %q", result.Messages[1].ToolCallID)
|
||||||
|
}
|
||||||
|
if result.Messages[2].Role != "user" {
|
||||||
|
t.Fatalf("expected third message to be user, got %q", result.Messages[2].Role)
|
||||||
|
}
|
||||||
|
if result.Messages[2].Content != "Please describe it." {
|
||||||
|
t.Fatalf("unexpected user content: %q", result.Messages[2].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_WithOutputConfigEffort(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "gemma4",
|
||||||
|
MaxTokens: 32000,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: textContent("Describe the image."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
OutputConfig: &OutputConfig{
|
||||||
|
Effort: "high",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Think == nil {
|
||||||
|
t.Fatal("expected think to be set from output_config.effort")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := result.Think.String(); got != "high" {
|
||||||
|
t.Fatalf("expected think level 'high', got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_WithOutputConfigEffortXHighMapsToHigh(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "gemma4",
|
||||||
|
MaxTokens: 32000,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: textContent("Describe the image."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
OutputConfig: &OutputConfig{
|
||||||
|
Effort: "xhigh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Think == nil {
|
||||||
|
t.Fatal("expected think to be set from output_config.effort")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := result.Think.String(); got != "high" {
|
||||||
|
t.Fatalf("expected think level 'high' for xhigh effort, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_ThinkingDisabledOverridesOutputConfigEffort(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "gemma4",
|
||||||
|
MaxTokens: 32000,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: textContent("Describe the image."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Thinking: &ThinkingConfig{
|
||||||
|
Type: "disabled",
|
||||||
|
},
|
||||||
|
OutputConfig: &OutputConfig{
|
||||||
|
Effort: "high",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Think == nil {
|
||||||
|
t.Fatal("expected think to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := result.Think.Value; got != false {
|
||||||
|
t.Fatalf("expected think=false when thinking is disabled, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromMessagesRequest_ThinkingAdaptiveUsesOutputConfigEffort(t *testing.T) {
|
||||||
|
req := MessagesRequest{
|
||||||
|
Model: "gemma4",
|
||||||
|
MaxTokens: 32000,
|
||||||
|
Messages: []MessageParam{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: textContent("Describe the image."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Thinking: &ThinkingConfig{
|
||||||
|
Type: "adaptive",
|
||||||
|
},
|
||||||
|
OutputConfig: &OutputConfig{
|
||||||
|
Effort: "high",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := FromMessagesRequest(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Think == nil {
|
||||||
|
t.Fatal("expected think to be set from output_config.effort")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := result.Think.String(); got != "high" {
|
||||||
|
t.Fatalf("expected think level 'high' for adaptive thinking, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
func TestFromMessagesRequest_WithTools(t *testing.T) {
|
||||||
req := MessagesRequest{
|
req := MessagesRequest{
|
||||||
Model: "test-model",
|
Model: "test-model",
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,8 @@ func (r *Gemma4Renderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||||
toolResponsesEmitted := false
|
toolResponsesEmitted := false
|
||||||
if len(message.ToolCalls) > 0 {
|
if len(message.ToolCalls) > 0 {
|
||||||
for k := i + 1; k < len(loopMessages) && loopMessages[k].Role == "tool"; k++ {
|
for k := i + 1; k < len(loopMessages) && loopMessages[k].Role == "tool"; k++ {
|
||||||
sb.WriteString(r.formatToolResponseBlock(r.toolResponseName(loopMessages[k], message.ToolCalls), loopMessages[k].Content))
|
response := r.renderToolResponseContent(loopMessages[k], &imageOffset)
|
||||||
|
sb.WriteString(r.formatToolResponseBlock(r.toolResponseName(loopMessages[k], message.ToolCalls), response))
|
||||||
toolResponsesEmitted = true
|
toolResponsesEmitted = true
|
||||||
prevMessageType = "tool_response"
|
prevMessageType = "tool_response"
|
||||||
}
|
}
|
||||||
|
|
@ -160,19 +161,22 @@ func stripThinking(text string) string {
|
||||||
// When trim is true, leading/trailing whitespace is stripped (matching the Jinja2
|
// When trim is true, leading/trailing whitespace is stripped (matching the Jinja2
|
||||||
// template's | trim filter applied to non-model content).
|
// template's | trim filter applied to non-model content).
|
||||||
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int, trim bool) {
|
func (r *Gemma4Renderer) renderContent(sb *strings.Builder, msg api.Message, imageOffset *int, trim bool) {
|
||||||
if len(msg.Images) > 0 && r.useImgTags {
|
|
||||||
for range msg.Images {
|
|
||||||
sb.WriteString(fmt.Sprintf("[img-%d]", *imageOffset))
|
|
||||||
*imageOffset++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content := msg.Content
|
content := msg.Content
|
||||||
if trim {
|
if trim {
|
||||||
content = strings.TrimSpace(content)
|
content = strings.TrimSpace(content)
|
||||||
}
|
}
|
||||||
|
if len(msg.Images) > 0 && r.useImgTags {
|
||||||
|
content, *imageOffset = renderContentWithImageTags(content, len(msg.Images), *imageOffset)
|
||||||
|
}
|
||||||
sb.WriteString(content)
|
sb.WriteString(content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Gemma4Renderer) renderToolResponseContent(msg api.Message, imageOffset *int) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
r.renderContent(&sb, msg, imageOffset, false)
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Gemma4Renderer) previousNonToolRole(messages []api.Message, idx int) string {
|
func (r *Gemma4Renderer) previousNonToolRole(messages []api.Message, idx int) string {
|
||||||
for i := idx - 1; i >= 0; i-- {
|
for i := idx - 1; i >= 0; i-- {
|
||||||
if messages[i].Role != "tool" {
|
if messages[i].Role != "tool" {
|
||||||
|
|
|
||||||
|
|
@ -13,15 +13,11 @@ type GlmOcrRenderer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
|
||||||
var sb strings.Builder
|
if r.useImgTags {
|
||||||
for range message.Images {
|
return renderContentWithImageTags(message.Content, len(message.Images), imageOffset)
|
||||||
if r.useImgTags {
|
|
||||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
|
||||||
imageOffset++
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
sb.WriteString(message.Content)
|
|
||||||
return sb.String(), imageOffset
|
return message.Content, imageOffset
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||||
|
|
@ -85,8 +81,10 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
|
||||||
if i == 0 || messages[i-1].Role != "tool" {
|
if i == 0 || messages[i-1].Role != "tool" {
|
||||||
sb.WriteString("<|observation|>")
|
sb.WriteString("<|observation|>")
|
||||||
}
|
}
|
||||||
|
content, nextOffset := r.renderContent(message, imageOffset)
|
||||||
|
imageOffset = nextOffset
|
||||||
sb.WriteString("\n<tool_response>\n")
|
sb.WriteString("\n<tool_response>\n")
|
||||||
sb.WriteString(message.Content)
|
sb.WriteString(content)
|
||||||
sb.WriteString("\n</tool_response>\n")
|
sb.WriteString("\n</tool_response>\n")
|
||||||
case "system":
|
case "system":
|
||||||
sb.WriteString("<|system|>\n")
|
sb.WriteString("<|system|>\n")
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||||
Images: []api.ImageData{api.ImageData("img1")},
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "[gMASK]<sop><|user|>\n[img-0]Describe this image.<|assistant|>\n",
|
expected: "[gMASK]<sop><|user|>\n[img-0] Describe this image.<|assistant|>\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "use_img_tags_multiple_images",
|
name: "use_img_tags_multiple_images",
|
||||||
|
|
@ -37,7 +37,7 @@ func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||||
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "[gMASK]<sop><|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
|
expected: "[gMASK]<sop><|user|>\n[img-0][img-1] Describe these images.<|assistant|>\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multi_turn_increments_image_offset",
|
name: "multi_turn_increments_image_offset",
|
||||||
|
|
@ -58,7 +58,7 @@ func TestGlmOcrRenderer_Images(t *testing.T) {
|
||||||
Images: []api.ImageData{api.ImageData("img2")},
|
Images: []api.ImageData{api.ImageData("img2")},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expected: "[gMASK]<sop><|user|>\n[img-0]First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
|
expected: "[gMASK]<sop><|user|>\n[img-0] First image<|assistant|>\n<think></think>\nProcessed.\n<|user|>\n[img-1] Second image<|assistant|>\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "default_no_img_tags",
|
name: "default_no_img_tags",
|
||||||
|
|
|
||||||
39
model/renderers/image_tags.go
Normal file
39
model/renderers/image_tags.go
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
"unicode/utf8"
|
||||||
|
)
|
||||||
|
|
||||||
|
// renderContentWithImageTags rewrites content so that it references imageCount
// images through numbered "[img-N]" tags, numbering from imageOffset. It
// returns the rewritten content and the offset to use for the next message.
// This keeps the legacy server-side placeholder semantics.
//
// The rules, applied in order:
//   - With no images (imageCount <= 0), content and offset are returned as-is.
//   - If content already contains a numbered tag ("[img-"), it is left
//     untouched and only the offset advances.
//   - Otherwise generic "[img]" placeholders are replaced left to right, one
//     per image. Images left over once the placeholders run out are prepended
//     as tags, separated from the content by a single space unless the content
//     is empty or already starts with whitespace. Placeholders left over once
//     the images run out are kept verbatim.
func renderContentWithImageTags(content string, imageCount int, imageOffset int) (string, int) {
	if imageCount <= 0 {
		return content, imageOffset
	}

	if strings.Contains(content, "[img-") {
		return content, imageOffset + imageCount
	}

	const placeholder = "[img]"

	// Substitute placeholders in one left-to-right pass instead of rescanning
	// the whole string once per image. An inserted tag can never form a new
	// "[img]" with its neighbours, so continuing after each match is safe.
	var body strings.Builder
	rest := content
	used := 0
	for used < imageCount {
		before, after, found := strings.Cut(rest, placeholder)
		if !found {
			break
		}
		body.WriteString(before)
		fmt.Fprintf(&body, "[img-%d]", imageOffset+used)
		rest = after
		used++
	}
	if used > 0 {
		body.WriteString(rest)
		content = body.String()
	}

	if used == imageCount {
		return content, imageOffset + imageCount
	}

	// Images without a placeholder are announced in front of the content.
	var prefix strings.Builder
	for i := used; i < imageCount; i++ {
		fmt.Fprintf(&prefix, "[img-%d]", imageOffset+i)
	}

	if content != "" {
		// No separator is added when the content starts with whitespace or
		// with a byte sequence that decodes to utf8.RuneError.
		if r, _ := utf8.DecodeRuneInString(content); r != utf8.RuneError && !unicode.IsSpace(r) {
			prefix.WriteByte(' ')
		}
	}

	return prefix.String() + content, imageOffset + imageCount
}
|
||||||
67
model/renderers/image_tags_test.go
Normal file
67
model/renderers/image_tags_test.go
Normal file
|
|
@ -0,0 +1,67 @@
|
||||||
|
package renderers
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestRenderContentWithImageTags(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
content string
|
||||||
|
imageCount int
|
||||||
|
imageOffset int
|
||||||
|
want string
|
||||||
|
wantOffset int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "prefixes when there are no placeholders",
|
||||||
|
content: "describe this image",
|
||||||
|
imageCount: 2,
|
||||||
|
imageOffset: 0,
|
||||||
|
want: "[img-0][img-1] describe this image",
|
||||||
|
wantOffset: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "replaces explicit placeholders in order",
|
||||||
|
content: "compare [img] and [img]",
|
||||||
|
imageCount: 2,
|
||||||
|
imageOffset: 3,
|
||||||
|
want: "compare [img-3] and [img-4]",
|
||||||
|
wantOffset: 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "prefixes extra images after placeholders are exhausted",
|
||||||
|
content: "compare [img]",
|
||||||
|
imageCount: 2,
|
||||||
|
imageOffset: 0,
|
||||||
|
want: "[img-1] compare [img-0]",
|
||||||
|
wantOffset: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "leaves leftover placeholders when there are fewer images",
|
||||||
|
content: "compare [img] and [img]",
|
||||||
|
imageCount: 1,
|
||||||
|
imageOffset: 0,
|
||||||
|
want: "compare [img-0] and [img]",
|
||||||
|
wantOffset: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserves already-numbered placeholders",
|
||||||
|
content: "compare [img-0] and [img-1]",
|
||||||
|
imageCount: 2,
|
||||||
|
imageOffset: 0,
|
||||||
|
want: "compare [img-0] and [img-1]",
|
||||||
|
wantOffset: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, gotOffset := renderContentWithImageTags(tt.content, tt.imageCount, tt.imageOffset)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("content = %q, want %q", got, tt.want)
|
||||||
|
}
|
||||||
|
if gotOffset != tt.wantOffset {
|
||||||
|
t.Fatalf("offset = %d, want %d", gotOffset, tt.wantOffset)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -3,7 +3,6 @@ package renderers
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
|
@ -199,19 +198,18 @@ func (r *LFM2Renderer) renderMessageContent(message api.Message, imageOffset int
|
||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
if r.useImgTags {
|
if r.useImgTags {
|
||||||
for i := range message.Images {
|
content, _ = renderContentWithImageTags(content, len(message.Images), imageOffset)
|
||||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
return content
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
placeholder := lfm2ImagePlaceholder(false)
|
var sb strings.Builder
|
||||||
if strings.Contains(content, placeholder) {
|
placeholder := lfm2ImagePlaceholder(false)
|
||||||
return content
|
if strings.Contains(content, placeholder) {
|
||||||
}
|
return content
|
||||||
for range message.Images {
|
}
|
||||||
sb.WriteString(placeholder)
|
for range message.Images {
|
||||||
}
|
sb.WriteString(placeholder)
|
||||||
}
|
}
|
||||||
sb.WriteString(content)
|
sb.WriteString(content)
|
||||||
return sb.String()
|
return sb.String()
|
||||||
|
|
|
||||||
|
|
@ -236,7 +236,7 @@ func TestLFM2Renderer_Images(t *testing.T) {
|
||||||
Content: "Describe this image.",
|
Content: "Describe this image.",
|
||||||
Images: []api.ImageData{api.ImageData("img1")},
|
Images: []api.ImageData{api.ImageData("img1")},
|
||||||
},
|
},
|
||||||
expected: "<|startoftext|><|im_start|>user\n[img-0]Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
expected: "<|startoftext|><|im_start|>user\n[img-0] Describe this image.<|im_end|>\n<|im_start|>assistant\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "existing_template_image_placeholder_not_duplicated",
|
name: "existing_template_image_placeholder_not_duplicated",
|
||||||
|
|
|
||||||
|
|
@ -79,12 +79,14 @@ func (r *Nemotron3NanoRenderer) Render(messages []api.Message, tools []api.Tool,
|
||||||
// Check if previous message was also a tool message
|
// Check if previous message was also a tool message
|
||||||
prevWasTool := i > 0 && loopMessages[i-1].Role == "tool"
|
prevWasTool := i > 0 && loopMessages[i-1].Role == "tool"
|
||||||
nextIsTool := i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool"
|
nextIsTool := i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool"
|
||||||
|
content := r.renderMessageContent(message, imageOffset)
|
||||||
|
imageOffset += len(message.Images)
|
||||||
|
|
||||||
if !prevWasTool {
|
if !prevWasTool {
|
||||||
sb.WriteString("<|im_start|>user\n")
|
sb.WriteString("<|im_start|>user\n")
|
||||||
}
|
}
|
||||||
sb.WriteString("<tool_response>\n")
|
sb.WriteString("<tool_response>\n")
|
||||||
sb.WriteString(message.Content)
|
sb.WriteString(content)
|
||||||
sb.WriteString("\n</tool_response>\n")
|
sb.WriteString("\n</tool_response>\n")
|
||||||
|
|
||||||
if !nextIsTool {
|
if !nextIsTool {
|
||||||
|
|
@ -237,23 +239,8 @@ func (r *Nemotron3NanoRenderer) renderMessageContent(message api.Message, imageO
|
||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(content, "[img-") {
|
content, _ = renderContentWithImageTags(content, len(message.Images), imageOffset)
|
||||||
return content
|
return content
|
||||||
}
|
|
||||||
|
|
||||||
if strings.Contains(content, "[img]") {
|
|
||||||
for i := range message.Images {
|
|
||||||
content = strings.Replace(content, "[img]", fmt.Sprintf("[img-%d]", imageOffset+i), 1)
|
|
||||||
}
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
for i := range message.Images {
|
|
||||||
sb.WriteString(fmt.Sprintf("[img-%d]", imageOffset+i))
|
|
||||||
}
|
|
||||||
sb.WriteString(content)
|
|
||||||
return sb.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func nemotron3NanoRenderContent(content any) string {
|
func nemotron3NanoRenderContent(content any) string {
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ func TestNemotron3NanoRenderer_Images(t *testing.T) {
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
|
{Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img1")}},
|
||||||
},
|
},
|
||||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0] Describe this image.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "generic image placeholder is rewritten",
|
name: "generic image placeholder is rewritten",
|
||||||
|
|
@ -35,7 +35,7 @@ func TestNemotron3NanoRenderer_Images(t *testing.T) {
|
||||||
{Role: "assistant", Content: "It shows something."},
|
{Role: "assistant", Content: "It shows something."},
|
||||||
{Role: "user", Content: "Compare these.", Images: []api.ImageData{api.ImageData("img2"), api.ImageData("img3")}},
|
{Role: "user", Content: "Compare these.", Images: []api.ImageData{api.ImageData("img2"), api.ImageData("img3")}},
|
||||||
},
|
},
|
||||||
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0]Describe the first image.<|im_end|>\n<|im_start|>assistant\n<think></think>It shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2]Compare these.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
expected: "\n\n\n<|im_start|>system\n<|im_end|>\n\n<|im_start|>user\n[img-0] Describe the first image.<|im_end|>\n<|im_start|>assistant\n<think></think>It shows something.<|im_end|>\n<|im_start|>user\n[img-1][img-2] Compare these.<|im_end|>\n\n<|im_start|>assistant\n<think>\n",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package renderers
|
package renderers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
|
@ -45,15 +44,14 @@ type Qwen35Renderer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
func (r *Qwen35Renderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||||
|
if r.useImgTags {
|
||||||
|
return renderContentWithImageTags(content.Content, len(content.Images), imageOffset)
|
||||||
|
}
|
||||||
|
|
||||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||||
var subSb strings.Builder
|
var subSb strings.Builder
|
||||||
for range content.Images {
|
for range content.Images {
|
||||||
if r.useImgTags {
|
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
|
||||||
imageOffset++
|
|
||||||
} else {
|
|
||||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// TODO: support videos
|
// TODO: support videos
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
package renderers
|
package renderers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
|
@ -15,18 +14,17 @@ type Qwen3VLRenderer struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
func (r *Qwen3VLRenderer) renderContent(content api.Message, imageOffset int) (string, int) {
|
||||||
|
if r.useImgTags {
|
||||||
|
return renderContentWithImageTags(content.Content, len(content.Images), imageOffset)
|
||||||
|
}
|
||||||
|
|
||||||
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
// This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
|
||||||
var subSb strings.Builder
|
var subSb strings.Builder
|
||||||
for range content.Images {
|
for range content.Images {
|
||||||
// TODO: (jmorganca): how to render this is different for different
|
// TODO: (jmorganca): how to render this is different for different
|
||||||
// model backends, and so we should eventually parameterize this or
|
// model backends, and so we should eventually parameterize this or
|
||||||
// only output a placeholder such as [img]
|
// only output a placeholder such as [img]
|
||||||
if r.useImgTags {
|
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
||||||
subSb.WriteString(fmt.Sprintf("[img-%d]", imageOffset))
|
|
||||||
imageOffset++
|
|
||||||
} else {
|
|
||||||
subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// TODO: support videos
|
// TODO: support videos
|
||||||
|
|
||||||
|
|
@ -126,7 +124,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, think
|
||||||
if i == 0 || messages[i-1].Role != "tool" {
|
if i == 0 || messages[i-1].Role != "tool" {
|
||||||
sb.WriteString("<|im_start|>user")
|
sb.WriteString("<|im_start|>user")
|
||||||
}
|
}
|
||||||
sb.WriteString("\n<tool_response>\n" + message.Content + "\n</tool_response>")
|
sb.WriteString("\n<tool_response>\n" + content + "\n</tool_response>")
|
||||||
if i == len(messages)-1 || messages[i+1].Role != "tool" {
|
if i == len(messages)-1 || messages[i+1].Role != "tool" {
|
||||||
sb.WriteString("<|im_end|>\n")
|
sb.WriteString("<|im_end|>\n")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,7 @@ Let me analyze this image.`,
|
||||||
},
|
},
|
||||||
useImgTags: true,
|
useImgTags: true,
|
||||||
expected: `<|im_start|>user
|
expected: `<|im_start|>user
|
||||||
[img-0]Describe this image.<|im_end|>
|
[img-0] Describe this image.<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Let me analyze this image.`,
|
Let me analyze this image.`,
|
||||||
},
|
},
|
||||||
|
|
@ -123,7 +123,7 @@ Let me analyze this image.`,
|
||||||
},
|
},
|
||||||
useImgTags: true,
|
useImgTags: true,
|
||||||
expected: `<|im_start|>user
|
expected: `<|im_start|>user
|
||||||
[img-0][img-1]Describe these images.<|im_end|>
|
[img-0][img-1] Describe these images.<|im_end|>
|
||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
Let me analyze this image.`,
|
Let me analyze this image.`,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -75,7 +75,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||||
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
|
||||||
}
|
}
|
||||||
|
|
||||||
for cnt, msg := range msgs[currMsgIdx:] {
|
renderMsgs := slices.Clone(msgs)
|
||||||
|
|
||||||
|
for cnt, msg := range renderMsgs[currMsgIdx:] {
|
||||||
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
|
||||||
return "", nil, errors.New("this model only supports one image while more than one image requested")
|
return "", nil, errors.New("this model only supports one image while more than one image requested")
|
||||||
}
|
}
|
||||||
|
|
@ -101,11 +103,16 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||||
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msgs[currMsgIdx+cnt].Content = prefix + prompt
|
|
||||||
|
if m.Config.Renderer != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
renderMsgs[currMsgIdx+cnt].Content = prefix + prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncate any messages that do not fit into the context window
|
// truncate any messages that do not fit into the context window
|
||||||
p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
|
p, err := renderPrompt(m, append(system, renderMsgs[currMsgIdx:]...), tools, think)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -401,11 +401,170 @@ func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
|
||||||
t.Fatalf("len(images) = %d, want %d", got, want)
|
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
|
if !strings.Contains(prompt, "<|user|>\n[img-0][img-1] extract text") {
|
||||||
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatPromptRendererAddsToolImageTags(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "look at this file",
|
||||||
|
Images: []api.ImageData{[]byte("img-1")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_read",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "Read",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "attached image",
|
||||||
|
Images: []api.ImageData{[]byte("img-2")},
|
||||||
|
ToolCallID: "call_read",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
renderer string
|
||||||
|
wantUserTag string
|
||||||
|
wantToolContent string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "gemma4",
|
||||||
|
renderer: "gemma4",
|
||||||
|
wantUserTag: "<|turn>user\n[img-0] look at this file<turn|>\n",
|
||||||
|
wantToolContent: "[img-1] attached image",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3-vl",
|
||||||
|
renderer: "qwen3-vl-instruct",
|
||||||
|
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||||
|
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3.5",
|
||||||
|
renderer: "qwen3.5",
|
||||||
|
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||||
|
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "glm-ocr",
|
||||||
|
renderer: "glm-ocr",
|
||||||
|
wantUserTag: "<|user|>\n[img-0] look at this file",
|
||||||
|
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nemotron-3-nano",
|
||||||
|
renderer: "nemotron-3-nano",
|
||||||
|
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
|
||||||
|
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := Model{
|
||||||
|
Config: model.ConfigV2{Renderer: tt.renderer},
|
||||||
|
ProjectorPaths: []string{"vision"},
|
||||||
|
}
|
||||||
|
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||||
|
think := false
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(images), 2; got != want {
|
||||||
|
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(prompt, tt.wantUserTag) {
|
||||||
|
t.Fatalf("prompt missing user image tag, got: %q", prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(prompt, tt.wantToolContent) {
|
||||||
|
t.Fatalf("prompt missing tool image tag, got: %q", prompt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChatPromptRendererPreservesExplicitImagePlaceholders(t *testing.T) {
|
||||||
|
msgs := []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "compare [img] and [img]",
|
||||||
|
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
renderer string
|
||||||
|
wantSnippet string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "gemma4",
|
||||||
|
renderer: "gemma4",
|
||||||
|
wantSnippet: "<|turn>user\ncompare [img-0] and [img-1]<turn|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3-vl",
|
||||||
|
renderer: "qwen3-vl-instruct",
|
||||||
|
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "qwen3.5",
|
||||||
|
renderer: "qwen3.5",
|
||||||
|
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "glm-ocr",
|
||||||
|
renderer: "glm-ocr",
|
||||||
|
wantSnippet: "<|user|>\ncompare [img-0] and [img-1]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nemotron-3-nano",
|
||||||
|
renderer: "nemotron-3-nano",
|
||||||
|
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
m := Model{
|
||||||
|
Config: model.ConfigV2{Renderer: tt.renderer},
|
||||||
|
ProjectorPaths: []string{"vision"},
|
||||||
|
}
|
||||||
|
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
|
||||||
|
think := false
|
||||||
|
|
||||||
|
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(images), 2; got != want {
|
||||||
|
t.Fatalf("len(images) = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(prompt, tt.wantSnippet) {
|
||||||
|
t.Fatalf("prompt missing replaced placeholders, got: %q", prompt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
|
||||||
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
msgs := []api.Message{{Role: "user", Content: "Hello"}}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue