Detailed changes
@@ -1091,13 +1091,22 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes
}
case fantasy.ToolResultContentTypeMedia:
if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
- content := r.Text
- if content == "" {
- content = fmt.Sprintf("Loaded %s content", r.MediaType)
+ if !stringext.IsValidBase64(r.Data) {
+ slog.Warn("Tool returned media with invalid base64 data, discarding image",
+ "tool", result.ToolName,
+ "tool_call_id", result.ToolCallID,
+ )
+ baseResult.Content = "Tool returned image data with invalid encoding"
+ baseResult.IsError = true
+ } else {
+ content := r.Text
+ if content == "" {
+ content = fmt.Sprintf("Loaded %s content", r.MediaType)
+ }
+ baseResult.Content = content
+ baseResult.Data = r.Data
+ baseResult.MIMEType = r.MediaType
}
- baseResult.Content = content
- baseResult.Data = r.Data
- baseResult.MIMEType = r.MediaType
}
}
@@ -0,0 +1,94 @@
+package agent
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "charm.land/fantasy"
+ "github.com/stretchr/testify/require"
+)
+
+func TestConvertToToolResult_InvalidBase64(t *testing.T) {
+ t.Parallel()
+
+ a := &sessionAgent{}
+ result := fantasy.ToolResultContent{
+ ToolCallID: "call_123",
+ ToolName: "test_tool",
+ Result: fantasy.ToolResultOutputContentMedia{
+ Data: "abc\x80def", // byte 0x80 is outside the base64 alphabet
+ MediaType: "image/png",
+ },
+ }
+
+ tr := a.convertToToolResult(result)
+ require.True(t, tr.IsError) // invalid media must be reported as an error
+ require.Empty(t, tr.Data) // and the undecodable payload must not be stored
+ require.Contains(t, tr.Content, "invalid encoding")
+ require.Equal(t, "call_123", tr.ToolCallID) // call metadata must survive the error path
+ require.Equal(t, "test_tool", tr.Name)
+}
+
+func TestConvertToToolResult_ValidMedia(t *testing.T) {
+ t.Parallel()
+
+ a := &sessionAgent{}
+ validData := base64.StdEncoding.EncodeToString([]byte("test image data")) // canonical, padded std base64
+
+ result := fantasy.ToolResultContent{
+ ToolCallID: "call_456",
+ ToolName: "screenshot",
+ Result: fantasy.ToolResultOutputContentMedia{
+ Data: validData,
+ MediaType: "image/png",
+ Text: "Screenshot captured",
+ },
+ }
+
+ tr := a.convertToToolResult(result)
+ require.False(t, tr.IsError)
+ require.Equal(t, validData, tr.Data) // valid payload is passed through unchanged
+ require.Equal(t, "image/png", tr.MIMEType)
+ require.Equal(t, "Screenshot captured", tr.Content) // tool-supplied Text is used as the content
+}
+
+func TestConvertToToolResult_ValidMediaNoText(t *testing.T) {
+ t.Parallel()
+
+ a := &sessionAgent{}
+ validData := base64.StdEncoding.EncodeToString([]byte("test image data")) // canonical, padded std base64
+
+ result := fantasy.ToolResultContent{
+ ToolCallID: "call_789",
+ ToolName: "view",
+ Result: fantasy.ToolResultOutputContentMedia{
+ Data: validData,
+ MediaType: "image/jpeg", // Text is deliberately left unset
+ },
+ }
+
+ tr := a.convertToToolResult(result)
+ require.False(t, tr.IsError)
+ require.Equal(t, validData, tr.Data)
+ require.Equal(t, "image/jpeg", tr.MIMEType)
+ require.Equal(t, "Loaded image/jpeg content", tr.Content) // fallback text built from the media type
+}
+
+func TestConvertToToolResult_ASCIIButInvalidBase64(t *testing.T) {
+ t.Parallel()
+
+ a := &sessionAgent{}
+ result := fantasy.ToolResultContent{
+ ToolCallID: "call_abc",
+ ToolName: "mcp_tool",
+ Result: fantasy.ToolResultOutputContentMedia{
+ Data: "not-valid-base64!!!", // pure ASCII, but '-' and '!' are not in the std base64 alphabet
+ MediaType: "image/png",
+ },
+ }
+
+ tr := a.convertToToolResult(result)
+ require.True(t, tr.IsError) // same error path as for non-ASCII data
+ require.Empty(t, tr.Data)
+ require.Contains(t, tr.Content, "invalid encoding")
+}
@@ -88,7 +88,7 @@ func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string
return ToolResult{
Type: "image",
Content: textContent,
- Data: ensureBase64(imageData),
+ Data: ensureRawBytes(imageData),
MediaType: imageMimeType,
}, nil
}
@@ -97,7 +97,7 @@ func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string
return ToolResult{
Type: "media",
Content: textContent,
- Data: ensureBase64(audioData),
+ Data: ensureRawBytes(audioData),
MediaType: audioMimeType,
}, nil
}
@@ -167,23 +167,29 @@ func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool)
return filtered
}
-// ensureBase64 normalizes valid base64 input and guarantees padded
-// base64.StdEncoding output; otherwise it encodes raw binary data.
-func ensureBase64(data []byte) []byte {
+// ensureRawBytes normalizes MCP media data into raw binary bytes.
+//
+// The MCP Go SDK's json.Unmarshal normally base64-decodes
+// ImageContent.Data into raw bytes automatically. However, some MCP
+// transports (notably Docker over stdio) can deliver data in
+// unexpected formats. This function handles both cases:
+//
+// - If data, once normalized, decodes as base64, the decoded raw
+// bytes are returned.
+// - Otherwise (empty or not decodable) data is returned as-is, on
+// the assumption that it is already raw binary.
+func ensureRawBytes(data []byte) []byte {
if len(data) == 0 {
return data
}
normalized := normalizeBase64Input(data)
if decoded, ok := decodeBase64(normalized); ok {
- encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded)))
- base64.StdEncoding.Encode(encoded, decoded)
- return encoded
+ return decoded
}
- encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
- base64.StdEncoding.Encode(encoded, data)
- return encoded
+ // Not decodable as base64 — treat as raw binary and return unchanged.
+ return data
}
func normalizeBase64Input(data []byte) []byte {
@@ -213,24 +219,3 @@ func decodeBase64(data []byte) ([]byte, bool) {
}
return nil, false
}
-
-// isValidBase64 checks if the data appears to be valid base64-encoded content.
-func isValidBase64(data []byte) bool {
- if len(data) == 0 {
- return true
- }
-
- // Base64 strings should only contain ASCII characters.
- for _, b := range data {
- if b > 127 {
- return false
- }
- }
-
- s := string(data)
- if _, err := base64.StdEncoding.DecodeString(s); err == nil {
- return true
- }
- _, err := base64.RawStdEncoding.DecodeString(s)
- return err == nil
-}
@@ -1,34 +1,35 @@
package mcp
import (
+ "bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/require"
)
-func TestEnsureBase64(t *testing.T) {
+func TestEnsureRawBytes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []byte
- wantData []byte // expected output
+ wantData []byte
}{
{
name: "already base64 encoded",
input: []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
- wantData: []byte("SGVsbG8gV29ybGQh"),
+ wantData: []byte("Hello World!"),
},
{
name: "raw binary data (PNG header)",
input: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
- wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
+ wantData: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
},
{
name: "raw binary with high bytes",
input: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
- wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
+ wantData: []byte{0xFF, 0xD8, 0xFF, 0xE0},
},
{
name: "empty data",
@@ -38,88 +39,31 @@ func TestEnsureBase64(t *testing.T) {
{
name: "base64 with padding",
input: []byte("YQ=="), // "a" in base64
- wantData: []byte("YQ=="),
+ wantData: []byte("a"),
},
{
name: "base64 without padding",
input: []byte("YQ"),
- wantData: []byte("YQ=="),
+ wantData: []byte("a"),
},
{
name: "base64 with whitespace",
input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
- wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
+ wantData: []byte("SGVsbG8gV29ybGQh"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
- result := ensureBase64(tt.input)
+ result := ensureRawBytes(tt.input)
require.Equal(t, tt.wantData, result)
- // Verify the result is valid base64 that can be decoded.
- if len(result) > 0 {
- _, err := base64.StdEncoding.DecodeString(string(result))
- if err != nil {
- _, err = base64.RawStdEncoding.DecodeString(string(result))
- }
- require.NoError(t, err, "result should be valid base64")
+ if len(result) > 0 && !bytes.Equal(result, tt.input) {
+ reEncoded := base64.StdEncoding.EncodeToString(result)
+ _, err := base64.StdEncoding.DecodeString(reEncoded)
+ require.NoError(t, err, "re-encoded result should be valid base64")
}
})
}
}
-
-func TestIsValidBase64(t *testing.T) {
- t.Parallel()
-
- tests := []struct {
- name string
- input []byte
- want bool
- }{
- {
- name: "valid base64",
- input: []byte("SGVsbG8gV29ybGQh"),
- want: true,
- },
- {
- name: "valid base64 with padding",
- input: []byte("YQ=="),
- want: true,
- },
- {
- name: "raw binary with high bytes",
- input: []byte{0xFF, 0xD8, 0xFF},
- want: false,
- },
- {
- name: "empty",
- input: []byte{},
- want: true,
- },
- {
- name: "valid raw base64 without padding",
- input: []byte("YQ"),
- want: true,
- },
- {
- name: "valid base64 with whitespace",
- input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
- want: true,
- },
- {
- name: "invalid base64 characters",
- input: []byte("SGVsbG8!@#$"),
- want: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- t.Parallel()
- got := isValidBase64(tt.input)
- require.Equal(t, tt.want, got)
- })
- }
-}
@@ -4,7 +4,6 @@ import (
"bufio"
"context"
_ "embed"
- "encoding/base64"
"fmt"
"io"
"io/fs"
@@ -189,8 +188,7 @@ func NewViewTool(
return fantasy.ToolResponse{}, fmt.Errorf("error reading image file: %w", readErr)
}
- encoded := base64.StdEncoding.EncodeToString(imageData)
- return fantasy.NewImageResponse([]byte(encoded), mimeType), nil
+ return fantasy.NewImageResponse(imageData, mimeType), nil
}
// Read the file content
@@ -13,6 +13,7 @@ import (
"charm.land/fantasy/providers/anthropic"
"charm.land/fantasy/providers/google"
"charm.land/fantasy/providers/openai"
+ "github.com/charmbracelet/crush/internal/stringext"
)
type MessageRole string
@@ -24,6 +25,10 @@ const (
Tool MessageRole = "tool"
)
+// mediaLoadFailedPlaceholder is the text ToAIMessage substitutes for a
+// tool result's media data when that data is not valid base64.
+const mediaLoadFailedPlaceholder = "[Image data could not be loaded]"
+
type FinishReason string
const (
@@ -542,9 +547,15 @@ func (m *Message) ToAIMessage() []fantasy.Message {
Error: errors.New(result.Content),
}
} else if result.Data != "" {
- content = fantasy.ToolResultOutputContentMedia{
- Data: result.Data,
- MediaType: result.MIMEType,
+ if stringext.IsValidBase64(result.Data) {
+ content = fantasy.ToolResultOutputContentMedia{
+ Data: result.Data,
+ MediaType: result.MIMEType,
+ }
+ } else {
+ content = fantasy.ToolResultOutputContentText{
+ Text: mediaLoadFailedPlaceholder,
+ }
}
} else {
content = fantasy.ToolResultOutputContentText{
@@ -1,9 +1,13 @@
package message
import (
+ "encoding/base64"
"fmt"
"strings"
"testing"
+
+ "charm.land/fantasy"
+ "github.com/stretchr/testify/require"
)
func makeTestAttachments(n int, contentSize int) []Attachment {
@@ -19,6 +23,99 @@ func makeTestAttachments(n int, contentSize int) []Attachment {
return attachments
}
+func TestToAIMessage_CorruptedMediaData(t *testing.T) {
+ t.Parallel()
+
+ msg := &Message{
+ Role: Tool,
+ Parts: []ContentPart{
+ ToolResult{
+ ToolCallID: "call_123",
+ Name: "screenshot",
+ Content: "Loaded image/png content",
+ Data: "abc\x80def", // byte 0x80 is outside the base64 alphabet
+ MIMEType: "image/png",
+ },
+ },
+ }
+
+ messages := msg.ToAIMessage()
+ require.Len(t, messages, 1)
+ require.Len(t, messages[0].Content, 1)
+
+ part, ok := messages[0].Content[0].(fantasy.ToolResultPart)
+ require.True(t, ok)
+
+ require.Equal(t, "call_123", part.ToolCallID) // the result must stay linked to its tool call
+
+ textContent, ok := part.Output.(fantasy.ToolResultOutputContentText) // media output must have become text output
+ require.True(t, ok, "corrupted media should be downgraded to text")
+ require.Equal(t, mediaLoadFailedPlaceholder, textContent.Text) // placeholder, not the stored Content
+}
+
+func TestToAIMessage_ValidMediaData(t *testing.T) {
+ t.Parallel()
+
+ validBase64 := base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47}) // canonical, padded std base64
+
+ msg := &Message{
+ Role: Tool,
+ Parts: []ContentPart{
+ ToolResult{
+ ToolCallID: "call_456",
+ Name: "screenshot",
+ Content: "Loaded image/png content",
+ Data: validBase64,
+ MIMEType: "image/png",
+ },
+ },
+ }
+
+ messages := msg.ToAIMessage()
+ require.Len(t, messages, 1)
+ require.Len(t, messages[0].Content, 1)
+
+ part, ok := messages[0].Content[0].(fantasy.ToolResultPart)
+ require.True(t, ok)
+
+ require.Equal(t, "call_456", part.ToolCallID)
+
+ mediaContent, ok := part.Output.(fantasy.ToolResultOutputContentMedia) // output must still be media, not text
+ require.True(t, ok, "valid media should remain as media")
+ require.Equal(t, validBase64, mediaContent.Data) // payload is forwarded unchanged
+ require.Equal(t, "image/png", mediaContent.MediaType)
+}
+
+func TestToAIMessage_ASCIIButInvalidBase64(t *testing.T) {
+ t.Parallel()
+
+ msg := &Message{
+ Role: Tool,
+ Parts: []ContentPart{
+ ToolResult{
+ ToolCallID: "call_789",
+ Name: "screenshot",
+ Content: "Loaded image/png content",
+ Data: "not-valid-base64!!!", // pure ASCII, but '-' and '!' are not in the std base64 alphabet
+ MIMEType: "image/png",
+ },
+ },
+ }
+
+ messages := msg.ToAIMessage()
+ require.Len(t, messages, 1)
+ require.Len(t, messages[0].Content, 1)
+
+ part, ok := messages[0].Content[0].(fantasy.ToolResultPart)
+ require.True(t, ok)
+
+ require.Equal(t, "call_789", part.ToolCallID)
+
+ textContent, ok := part.Output.(fantasy.ToolResultOutputContentText) // media output must have become text output
+ require.True(t, ok, "ASCII but invalid base64 should be downgraded to text")
+ require.Equal(t, mediaLoadFailedPlaceholder, textContent.Text)
+}
+
func BenchmarkPromptWithTextAttachments(b *testing.B) {
cases := []struct {
name string
@@ -1,6 +1,7 @@
package stringext
import (
+ "encoding/base64"
"strings"
"golang.org/x/text/cases"
@@ -20,3 +21,20 @@ func NormalizeSpace(content string) string {
content = strings.TrimSpace(content)
return content
}
+
+// IsValidBase64 reports whether s is non-empty, canonical base64 under
+// standard padded encoding (RFC 4648). Beyond decoding cleanly, s must
+// round-trip through decode/encode unchanged, which rejects embedded
+// CR/LF and non-zero trailing bits that DecodeString alone accepts.
+func IsValidBase64(s string) bool {
+ if s == "" {
+ return false // an empty string is never considered valid
+ }
+ decoded, err := base64.StdEncoding.DecodeString(s)
+ if err != nil {
+ return false
+ }
+ // Round-trip check: DecodeString ignores CR/LF and tolerates non-zero
+ // trailing padding bits, so compare against the canonical re-encoding.
+ return base64.StdEncoding.EncodeToString(decoded) == s
+}
@@ -0,0 +1,38 @@
+package stringext
+
+import (
+ "encoding/base64"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestIsValidBase64(t *testing.T) {
+ t.Parallel()
+
+ // Real PNG header encoded in standard base64.
+ pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
+ pngBase64 := base64.StdEncoding.EncodeToString(pngHeader)
+
+ tests := []struct {
+ name string
+ input string
+ expected bool
+ }{
+ {name: "empty string", input: "", expected: false}, // rejected by the explicit empty check
+ {name: "valid no padding", input: "SGVsbG8gV29ybGQh", expected: true}, // length is a multiple of 4, so no '=' is needed
+ {name: "valid with padding", input: "YQ==", expected: true},
+ {name: "non-ASCII bytes", input: "abc\x80def", expected: false}, // byte 0x80 is outside the base64 alphabet
+ {name: "ASCII but not base64", input: "hello world!!!", expected: false}, // ' ' and '!' are outside the base64 alphabet
+ {name: "raw encoding no padding", input: "YQ", expected: false}, // StdEncoding requires '=' padding, so decoding fails
+ {name: "trailing whitespace", input: "YQ==\n", expected: false}, // decodes (the decoder skips "\n") but fails the round-trip check
+ {name: "valid PNG header base64", input: pngBase64, expected: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.expected, IsValidBase64(tt.input))
+ })
+ }
+}