From efe42e801fbb5a44db1463f2e8ace2957328e504 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Sun, 12 Apr 2026 10:07:41 -0400 Subject: [PATCH] fix(agent): prevent session corruption due to malformed image data (#2597) --- internal/agent/agent.go | 21 ++++-- internal/agent/convert_test.go | 94 +++++++++++++++++++++++++ internal/agent/tools/mcp/tools.go | 49 +++++-------- internal/agent/tools/mcp/tools_test.go | 84 ++++------------------ internal/agent/tools/view.go | 4 +- internal/message/content.go | 17 ++++- internal/message/content_test.go | 97 ++++++++++++++++++++++++++ internal/stringext/string.go | 18 +++++ internal/stringext/string_test.go | 38 ++++++++++ 9 files changed, 308 insertions(+), 114 deletions(-) create mode 100644 internal/agent/convert_test.go create mode 100644 internal/stringext/string_test.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index c8a8cecd22400049e793e8f7c312dc6384cf68b2..ace750512ee94c69b67045620934a5d828dfd2db 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -1091,13 +1091,22 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes } case fantasy.ToolResultContentTypeMedia: if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok { - content := r.Text - if content == "" { - content = fmt.Sprintf("Loaded %s content", r.MediaType) + if !stringext.IsValidBase64(r.Data) { + slog.Warn("Tool returned media with invalid base64 data, discarding image", + "tool", result.ToolName, + "tool_call_id", result.ToolCallID, + ) + baseResult.Content = "Tool returned image data with invalid encoding" + baseResult.IsError = true + } else { + content := r.Text + if content == "" { + content = fmt.Sprintf("Loaded %s content", r.MediaType) + } + baseResult.Content = content + baseResult.Data = r.Data + baseResult.MIMEType = r.MediaType } - baseResult.Content = content - baseResult.Data = r.Data - baseResult.MIMEType = r.MediaType } } diff --git a/internal/agent/convert_test.go b/internal/agent/convert_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4aa382e7cf0015c825c806678908ff97fc8828dd --- /dev/null +++ b/internal/agent/convert_test.go @@ -0,0 +1,94 @@ +package agent + +import ( + "encoding/base64" + "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/require" +) + +func TestConvertToToolResult_InvalidBase64(t *testing.T) { + t.Parallel() + + a := &sessionAgent{} + result := fantasy.ToolResultContent{ + ToolCallID: "call_123", + ToolName: "test_tool", + Result: fantasy.ToolResultOutputContentMedia{ + Data: "abc\x80def", + MediaType: "image/png", + }, + } + + tr := a.convertToToolResult(result) + require.True(t, tr.IsError) + require.Empty(t, tr.Data) + require.Contains(t, tr.Content, "invalid encoding") + require.Equal(t, "call_123", tr.ToolCallID) + require.Equal(t, "test_tool", tr.Name) +} + +func TestConvertToToolResult_ValidMedia(t *testing.T) { + t.Parallel() + + a := &sessionAgent{} + validData := base64.StdEncoding.EncodeToString([]byte("test image data")) + + result := fantasy.ToolResultContent{ + ToolCallID: "call_456", + ToolName: "screenshot", + Result: fantasy.ToolResultOutputContentMedia{ + Data: validData, + MediaType: "image/png", + Text: "Screenshot captured", + }, + } + + tr := a.convertToToolResult(result) + require.False(t, tr.IsError) + require.Equal(t, validData, tr.Data) + require.Equal(t, "image/png", tr.MIMEType) + require.Equal(t, "Screenshot captured", tr.Content) +} + +func TestConvertToToolResult_ValidMediaNoText(t *testing.T) { + t.Parallel() + + a := &sessionAgent{} + validData := base64.StdEncoding.EncodeToString([]byte("test image data")) + + result := fantasy.ToolResultContent{ + ToolCallID: "call_789", + ToolName: "view", + Result: fantasy.ToolResultOutputContentMedia{ + Data: validData, + MediaType: "image/jpeg", + }, + } + + tr := a.convertToToolResult(result) + require.False(t, tr.IsError) + require.Equal(t, validData, tr.Data) + require.Equal(t, "image/jpeg", tr.MIMEType) + require.Equal(t, "Loaded image/jpeg content", tr.Content) +} + +func TestConvertToToolResult_ASCIIButInvalidBase64(t *testing.T) { + t.Parallel() + + a := &sessionAgent{} + result := fantasy.ToolResultContent{ + ToolCallID: "call_abc", + ToolName: "mcp_tool", + Result: fantasy.ToolResultOutputContentMedia{ + Data: "not-valid-base64!!!", + MediaType: "image/png", + }, + } + + tr := a.convertToToolResult(result) + require.True(t, tr.IsError) + require.Empty(t, tr.Data) + require.Contains(t, tr.Content, "invalid encoding") +} diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index 05d6b2b75d8fadff2e9af8385817ac135722f1a8..4da41779abf53cf825c701abab03bd4c84aa8298 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -88,7 +88,7 @@ func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string return ToolResult{ Type: "image", Content: textContent, - Data: ensureBase64(imageData), + Data: ensureRawBytes(imageData), MediaType: imageMimeType, }, nil } @@ -97,7 +97,7 @@ func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string return ToolResult{ Type: "media", Content: textContent, - Data: ensureBase64(audioData), + Data: ensureRawBytes(audioData), MediaType: audioMimeType, }, nil } @@ -167,23 +167,29 @@ func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) return filtered } -// ensureBase64 normalizes valid base64 input and guarantees padded -// base64.StdEncoding output; otherwise it encodes raw binary data. -func ensureBase64(data []byte) []byte { +// ensureRawBytes normalizes MCP media data into raw binary bytes. +// +// The MCP Go SDK's json.Unmarshal normally base64-decodes +// ImageContent.Data into raw bytes automatically. However, some MCP +// transports (notably Docker over stdio) can deliver data in +// unexpected formats. This function handles both cases: +// +// - If data looks like a valid base64 string (ASCII-only, decodable) +// it is decoded and the raw bytes are returned. +// - If data is already raw binary (contains bytes > 127) it is +// returned as-is. +func ensureRawBytes(data []byte) []byte { if len(data) == 0 { return data } normalized := normalizeBase64Input(data) if decoded, ok := decodeBase64(normalized); ok { - encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded))) - base64.StdEncoding.Encode(encoded, decoded) - return encoded + return decoded } - encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(encoded, data) - return encoded + // Already raw binary — return unchanged. + return data } func normalizeBase64Input(data []byte) []byte { @@ -213,24 +219,3 @@ func decodeBase64(data []byte) ([]byte, bool) { } return nil, false } - -// isValidBase64 checks if the data appears to be valid base64-encoded content. -func isValidBase64(data []byte) bool { - if len(data) == 0 { - return true - } - - // Base64 strings should only contain ASCII characters. - for _, b := range data { - if b > 127 { - return false - } - } - - s := string(data) - if _, err := base64.StdEncoding.DecodeString(s); err == nil { - return true - } - _, err := base64.RawStdEncoding.DecodeString(s) - return err == nil -} diff --git a/internal/agent/tools/mcp/tools_test.go b/internal/agent/tools/mcp/tools_test.go index aae4428ed6b830549540611761c22f070eeda925..935e17be42be5e45a592a5aed909aa4a2bfb3d48 100644 --- a/internal/agent/tools/mcp/tools_test.go +++ b/internal/agent/tools/mcp/tools_test.go @@ -1,34 +1,35 @@ package mcp import ( + "bytes" "encoding/base64" "testing" "github.com/stretchr/testify/require" ) -func TestEnsureBase64(t *testing.T) { +func TestEnsureRawBytes(t *testing.T) { t.Parallel() tests := []struct { name string input []byte - wantData []byte // expected output + wantData []byte }{ { name: "already base64 encoded", input: []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64 - wantData: []byte("SGVsbG8gV29ybGQh"), + wantData: []byte("Hello World!"), }, { name: "raw binary data (PNG header)", input: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, - wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})), + wantData: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, }, { name: "raw binary with high bytes", input: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header - wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})), + wantData: []byte{0xFF, 0xD8, 0xFF, 0xE0}, }, { name: "empty data", @@ -38,88 +39,31 @@ func TestEnsureBase64(t *testing.T) { { name: "base64 with padding", input: []byte("YQ=="), // "a" in base64 - wantData: []byte("YQ=="), + wantData: []byte("a"), }, { name: "base64 without padding", input: []byte("YQ"), - wantData: []byte("YQ=="), + wantData: []byte("a"), }, { name: "base64 with whitespace", input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"), - wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="), + wantData: []byte("SGVsbG8gV29ybGQh"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result := ensureBase64(tt.input) + result := ensureRawBytes(tt.input) require.Equal(t, tt.wantData, result) - // Verify the result is valid base64 that can be decoded. - if len(result) > 0 { - _, err := base64.StdEncoding.DecodeString(string(result)) - if err != nil { - _, err = base64.RawStdEncoding.DecodeString(string(result)) - } - require.NoError(t, err, "result should be valid base64") + if len(result) > 0 && !bytes.Equal(result, tt.input) { + reEncoded := base64.StdEncoding.EncodeToString(result) + _, err := base64.StdEncoding.DecodeString(reEncoded) + require.NoError(t, err, "re-encoded result should be valid base64") } }) } } - -func TestIsValidBase64(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input []byte - want bool - }{ - { - name: "valid base64", - input: []byte("SGVsbG8gV29ybGQh"), - want: true, - }, - { - name: "valid base64 with padding", - input: []byte("YQ=="), - want: true, - }, - { - name: "raw binary with high bytes", - input: []byte{0xFF, 0xD8, 0xFF}, - want: false, - }, - { - name: "empty", - input: []byte{}, - want: true, - }, - { - name: "valid raw base64 without padding", - input: []byte("YQ"), - want: true, - }, - { - name: "valid base64 with whitespace", - input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")), - want: true, - }, - { - name: "invalid base64 characters", - input: []byte("SGVsbG8!@#$"), - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - got := isValidBase64(tt.input) - require.Equal(t, tt.want, got) - }) - } -} diff --git a/internal/agent/tools/view.go b/internal/agent/tools/view.go index 37e3c3a1ead9ff53455f2447272cf0be60c734d2..efacba1423ee0acafb8361537982af47ffbf74a1 100644 --- a/internal/agent/tools/view.go +++ b/internal/agent/tools/view.go @@ -4,7 +4,6 @@ import ( "bufio" "context" _ "embed" - "encoding/base64" "fmt" "io" "io/fs" @@ -189,8 +188,7 @@ func NewViewTool( return fantasy.ToolResponse{}, fmt.Errorf("error reading image file: %w", readErr) } - encoded := base64.StdEncoding.EncodeToString(imageData) - return fantasy.NewImageResponse([]byte(encoded), mimeType), nil + return fantasy.NewImageResponse(imageData, mimeType), nil } // Read the file content diff --git a/internal/message/content.go b/internal/message/content.go index 02f949334b688e4dd40c832d5f68d52523ac9953..b13e69c3cf6af556dde4dc22c7c589a702e39f45 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -13,6 +13,7 @@ import ( "charm.land/fantasy/providers/anthropic" "charm.land/fantasy/providers/google" "charm.land/fantasy/providers/openai" + "github.com/charmbracelet/crush/internal/stringext" ) type MessageRole string @@ -24,6 +25,10 @@ const ( Tool MessageRole = "tool" ) +// mediaLoadFailedPlaceholder is the text substituted for image data that +// cannot be decoded during session replay. +const mediaLoadFailedPlaceholder = "[Image data could not be loaded]" + type FinishReason string const ( @@ -542,9 +547,15 @@ func (m *Message) ToAIMessage() []fantasy.Message { Error: errors.New(result.Content), } } else if result.Data != "" { - content = fantasy.ToolResultOutputContentMedia{ - Data: result.Data, - MediaType: result.MIMEType, + if stringext.IsValidBase64(result.Data) { + content = fantasy.ToolResultOutputContentMedia{ + Data: result.Data, + MediaType: result.MIMEType, + } + } else { + content = fantasy.ToolResultOutputContentText{ + Text: mediaLoadFailedPlaceholder, + } } } else { content = fantasy.ToolResultOutputContentText{ diff --git a/internal/message/content_test.go b/internal/message/content_test.go index 7e9e273c57e4b6cee2df8cd6b74bf455797bce36..04e601012aa83e9512e3de4be8386adf6ab909cc 100644 --- a/internal/message/content_test.go +++ b/internal/message/content_test.go @@ -1,9 +1,13 @@ package message import ( + "encoding/base64" "fmt" "strings" "testing" + + "charm.land/fantasy" + "github.com/stretchr/testify/require" ) func makeTestAttachments(n int, contentSize int) []Attachment { @@ -19,6 +23,99 @@ func makeTestAttachments(n int, contentSize int) []Attachment { return attachments } +func TestToAIMessage_CorruptedMediaData(t *testing.T) { + t.Parallel() + + msg := &Message{ + Role: Tool, + Parts: []ContentPart{ + ToolResult{ + ToolCallID: "call_123", + Name: "screenshot", + Content: "Loaded image/png content", + Data: "abc\x80def", + MIMEType: "image/png", + }, + }, + } + + messages := msg.ToAIMessage() + require.Len(t, messages, 1) + require.Len(t, messages[0].Content, 1) + + part, ok := messages[0].Content[0].(fantasy.ToolResultPart) + require.True(t, ok) + + require.Equal(t, "call_123", part.ToolCallID) + + textContent, ok := part.Output.(fantasy.ToolResultOutputContentText) + require.True(t, ok, "corrupted media should be downgraded to text") + require.Equal(t, mediaLoadFailedPlaceholder, textContent.Text) +} + +func TestToAIMessage_ValidMediaData(t *testing.T) { + t.Parallel() + + validBase64 := base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47}) + + msg := &Message{ + Role: Tool, + Parts: []ContentPart{ + ToolResult{ + ToolCallID: "call_456", + Name: "screenshot", + Content: "Loaded image/png content", + Data: validBase64, + MIMEType: "image/png", + }, + }, + } + + messages := msg.ToAIMessage() + require.Len(t, messages, 1) + require.Len(t, messages[0].Content, 1) + + part, ok := messages[0].Content[0].(fantasy.ToolResultPart) + require.True(t, ok) + + require.Equal(t, "call_456", part.ToolCallID) + + mediaContent, ok := part.Output.(fantasy.ToolResultOutputContentMedia) + require.True(t, ok, "valid media should remain as media") + require.Equal(t, validBase64, mediaContent.Data) + require.Equal(t, "image/png", mediaContent.MediaType) +} + +func TestToAIMessage_ASCIIButInvalidBase64(t *testing.T) { + t.Parallel() + + msg := &Message{ + Role: Tool, + Parts: []ContentPart{ + ToolResult{ + ToolCallID: "call_789", + Name: "screenshot", + Content: "Loaded image/png content", + Data: "not-valid-base64!!!", + MIMEType: "image/png", + }, + }, + } + + messages := msg.ToAIMessage() + require.Len(t, messages, 1) + require.Len(t, messages[0].Content, 1) + + part, ok := messages[0].Content[0].(fantasy.ToolResultPart) + require.True(t, ok) + + require.Equal(t, "call_789", part.ToolCallID) + + textContent, ok := part.Output.(fantasy.ToolResultOutputContentText) + require.True(t, ok, "ASCII but invalid base64 should be downgraded to text") + require.Equal(t, mediaLoadFailedPlaceholder, textContent.Text) +} + func BenchmarkPromptWithTextAttachments(b *testing.B) { cases := []struct { name string diff --git a/internal/stringext/string.go b/internal/stringext/string.go index 8be28ccc2096c3d54b9f3106ed30d584503acdf4..de79ee866755250396dcc8124bfceda8063005ad 100644 --- a/internal/stringext/string.go +++ b/internal/stringext/string.go @@ -1,6 +1,7 @@ package stringext import ( + "encoding/base64" "strings" "golang.org/x/text/cases" @@ -20,3 +21,20 @@ func NormalizeSpace(content string) string { content = strings.TrimSpace(content) return content } + +// IsValidBase64 reports whether s is canonical base64 under standard +// encoding (RFC 4648). It requires that s round-trips through +// decode/encode unchanged — rejecting whitespace, missing padding, +// and other leniencies that DecodeString alone would accept. +func IsValidBase64(s string) bool { + if s == "" { + return false + } + decoded, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return false + } + // Round-trip check rejects whitespace, missing padding, and other + // leniencies that DecodeString silently accepts. + return base64.StdEncoding.EncodeToString(decoded) == s +} diff --git a/internal/stringext/string_test.go b/internal/stringext/string_test.go new file mode 100644 index 0000000000000000000000000000000000000000..55557bc98ea144d7dc9ecaead3c11f6588821056 --- /dev/null +++ b/internal/stringext/string_test.go @@ -0,0 +1,38 @@ +package stringext + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsValidBase64(t *testing.T) { + t.Parallel() + + // Real PNG header encoded in standard base64. + pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + pngBase64 := base64.StdEncoding.EncodeToString(pngHeader) + + tests := []struct { + name string + input string + expected bool + }{ + {name: "empty string", input: "", expected: false}, + {name: "valid no padding", input: "SGVsbG8gV29ybGQh", expected: true}, + {name: "valid with padding", input: "YQ==", expected: true}, + {name: "non-ASCII bytes", input: "abc\x80def", expected: false}, + {name: "ASCII but not base64", input: "hello world!!!", expected: false}, + {name: "raw encoding no padding", input: "YQ", expected: false}, + {name: "trailing whitespace", input: "YQ==\n", expected: false}, + {name: "valid PNG header base64", input: pngBase64, expected: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.expected, IsValidBase64(tt.input)) + }) + } +}