fix(agent): prevent session corruption due to malformed image data (#2597)

Christian Rocha created

Change summary

internal/agent/agent.go                | 21 ++++-
internal/agent/convert_test.go         | 94 +++++++++++++++++++++++++++
internal/agent/tools/mcp/tools.go      | 49 ++++---------
internal/agent/tools/mcp/tools_test.go | 84 ++++--------------------
internal/agent/tools/view.go           |  4 
internal/message/content.go            | 17 ++++
internal/message/content_test.go       | 97 ++++++++++++++++++++++++++++
internal/stringext/string.go           | 18 +++++
internal/stringext/string_test.go      | 38 ++++++++++
9 files changed, 308 insertions(+), 114 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -1091,13 +1091,22 @@ func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) mes
 		}
 	case fantasy.ToolResultContentTypeMedia:
 		if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
-			content := r.Text
-			if content == "" {
-				content = fmt.Sprintf("Loaded %s content", r.MediaType)
+			if !stringext.IsValidBase64(r.Data) {
+				slog.Warn("Tool returned media with invalid base64 data, discarding image",
+					"tool", result.ToolName,
+					"tool_call_id", result.ToolCallID,
+				)
+				baseResult.Content = "Tool returned image data with invalid encoding"
+				baseResult.IsError = true
+			} else {
+				content := r.Text
+				if content == "" {
+					content = fmt.Sprintf("Loaded %s content", r.MediaType)
+				}
+				baseResult.Content = content
+				baseResult.Data = r.Data
+				baseResult.MIMEType = r.MediaType
 			}
-			baseResult.Content = content
-			baseResult.Data = r.Data
-			baseResult.MIMEType = r.MediaType
 		}
 	}
 

internal/agent/convert_test.go 🔗

@@ -0,0 +1,94 @@
+package agent
+
+import (
+	"encoding/base64"
+	"testing"
+
+	"charm.land/fantasy"
+	"github.com/stretchr/testify/require"
+)
+
+func TestConvertToToolResult_InvalidBase64(t *testing.T) {
+	t.Parallel()
+
+	a := &sessionAgent{}
+	result := fantasy.ToolResultContent{
+		ToolCallID: "call_123",
+		ToolName:   "test_tool",
+		Result: fantasy.ToolResultOutputContentMedia{
+			Data:      "abc\x80def",
+			MediaType: "image/png",
+		},
+	}
+
+	tr := a.convertToToolResult(result)
+	require.True(t, tr.IsError)
+	require.Empty(t, tr.Data)
+	require.Contains(t, tr.Content, "invalid encoding")
+	require.Equal(t, "call_123", tr.ToolCallID)
+	require.Equal(t, "test_tool", tr.Name)
+}
+
+func TestConvertToToolResult_ValidMedia(t *testing.T) {
+	t.Parallel()
+
+	a := &sessionAgent{}
+	validData := base64.StdEncoding.EncodeToString([]byte("test image data"))
+
+	result := fantasy.ToolResultContent{
+		ToolCallID: "call_456",
+		ToolName:   "screenshot",
+		Result: fantasy.ToolResultOutputContentMedia{
+			Data:      validData,
+			MediaType: "image/png",
+			Text:      "Screenshot captured",
+		},
+	}
+
+	tr := a.convertToToolResult(result)
+	require.False(t, tr.IsError)
+	require.Equal(t, validData, tr.Data)
+	require.Equal(t, "image/png", tr.MIMEType)
+	require.Equal(t, "Screenshot captured", tr.Content)
+}
+
+func TestConvertToToolResult_ValidMediaNoText(t *testing.T) {
+	t.Parallel()
+
+	a := &sessionAgent{}
+	validData := base64.StdEncoding.EncodeToString([]byte("test image data"))
+
+	result := fantasy.ToolResultContent{
+		ToolCallID: "call_789",
+		ToolName:   "view",
+		Result: fantasy.ToolResultOutputContentMedia{
+			Data:      validData,
+			MediaType: "image/jpeg",
+		},
+	}
+
+	tr := a.convertToToolResult(result)
+	require.False(t, tr.IsError)
+	require.Equal(t, validData, tr.Data)
+	require.Equal(t, "image/jpeg", tr.MIMEType)
+	require.Equal(t, "Loaded image/jpeg content", tr.Content)
+}
+
+func TestConvertToToolResult_ASCIIButInvalidBase64(t *testing.T) {
+	t.Parallel()
+
+	a := &sessionAgent{}
+	result := fantasy.ToolResultContent{
+		ToolCallID: "call_abc",
+		ToolName:   "mcp_tool",
+		Result: fantasy.ToolResultOutputContentMedia{
+			Data:      "not-valid-base64!!!",
+			MediaType: "image/png",
+		},
+	}
+
+	tr := a.convertToToolResult(result)
+	require.True(t, tr.IsError)
+	require.Empty(t, tr.Data)
+	require.Contains(t, tr.Content, "invalid encoding")
+}

internal/agent/tools/mcp/tools.go 🔗

@@ -88,7 +88,7 @@ func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string
 		return ToolResult{
 			Type:      "image",
 			Content:   textContent,
-			Data:      ensureBase64(imageData),
+			Data:      ensureRawBytes(imageData),
 			MediaType: imageMimeType,
 		}, nil
 	}
@@ -97,7 +97,7 @@ func RunTool(ctx context.Context, cfg *config.ConfigStore, name, toolName string
 		return ToolResult{
 			Type:      "media",
 			Content:   textContent,
-			Data:      ensureBase64(audioData),
+			Data:      ensureRawBytes(audioData),
 			MediaType: audioMimeType,
 		}, nil
 	}
@@ -167,23 +167,29 @@ func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool)
 	return filtered
 }
 
-// ensureBase64 normalizes valid base64 input and guarantees padded
-// base64.StdEncoding output; otherwise it encodes raw binary data.
-func ensureBase64(data []byte) []byte {
+// ensureRawBytes normalizes MCP media data into raw binary bytes.
+//
+// The MCP Go SDK's json.Unmarshal normally base64-decodes
+// ImageContent.Data into raw bytes automatically. However, some MCP
+// transports (notably Docker over stdio) can deliver data in
+// unexpected formats. This function handles both cases:
+//
+//   - If data looks like a valid base64 string (ASCII-only, decodable)
+//     it is decoded and the raw bytes are returned.
+//   - Otherwise (decoding fails) data is assumed to already be raw
+//     binary and is returned as-is.
+func ensureRawBytes(data []byte) []byte {
 	if len(data) == 0 {
 		return data
 	}
 
 	normalized := normalizeBase64Input(data)
 	if decoded, ok := decodeBase64(normalized); ok {
-		encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded)))
-		base64.StdEncoding.Encode(encoded, decoded)
-		return encoded
+		return decoded
 	}
 
-	encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
-	base64.StdEncoding.Encode(encoded, data)
-	return encoded
+	// Already raw binary — return unchanged.
+	return data
 }
 
 func normalizeBase64Input(data []byte) []byte {
@@ -213,24 +219,3 @@ func decodeBase64(data []byte) ([]byte, bool) {
 	}
 	return nil, false
 }
-
-// isValidBase64 checks if the data appears to be valid base64-encoded content.
-func isValidBase64(data []byte) bool {
-	if len(data) == 0 {
-		return true
-	}
-
-	// Base64 strings should only contain ASCII characters.
-	for _, b := range data {
-		if b > 127 {
-			return false
-		}
-	}
-
-	s := string(data)
-	if _, err := base64.StdEncoding.DecodeString(s); err == nil {
-		return true
-	}
-	_, err := base64.RawStdEncoding.DecodeString(s)
-	return err == nil
-}

internal/agent/tools/mcp/tools_test.go 🔗

@@ -1,34 +1,35 @@
 package mcp
 
 import (
+	"bytes"
 	"encoding/base64"
 	"testing"
 
 	"github.com/stretchr/testify/require"
 )
 
-func TestEnsureBase64(t *testing.T) {
+func TestEnsureRawBytes(t *testing.T) {
 	t.Parallel()
 
 	tests := []struct {
 		name     string
 		input    []byte
-		wantData []byte // expected output
+		wantData []byte
 	}{
 		{
 			name:     "already base64 encoded",
 			input:    []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
-			wantData: []byte("SGVsbG8gV29ybGQh"),
+			wantData: []byte("Hello World!"),
 		},
 		{
 			name:     "raw binary data (PNG header)",
 			input:    []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
-			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
+			wantData: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
 		},
 		{
 			name:     "raw binary with high bytes",
 			input:    []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
-			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
+			wantData: []byte{0xFF, 0xD8, 0xFF, 0xE0},
 		},
 		{
 			name:     "empty data",
@@ -38,88 +39,31 @@ func TestEnsureBase64(t *testing.T) {
 		{
 			name:     "base64 with padding",
 			input:    []byte("YQ=="), // "a" in base64
-			wantData: []byte("YQ=="),
+			wantData: []byte("a"),
 		},
 		{
 			name:     "base64 without padding",
 			input:    []byte("YQ"),
-			wantData: []byte("YQ=="),
+			wantData: []byte("a"),
 		},
 		{
 			name:     "base64 with whitespace",
 			input:    []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
-			wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
+			wantData: []byte("SGVsbG8gV29ybGQh"),
 		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			t.Parallel()
-			result := ensureBase64(tt.input)
+			result := ensureRawBytes(tt.input)
 			require.Equal(t, tt.wantData, result)
 
-			// Verify the result is valid base64 that can be decoded.
-			if len(result) > 0 {
-				_, err := base64.StdEncoding.DecodeString(string(result))
-				if err != nil {
-					_, err = base64.RawStdEncoding.DecodeString(string(result))
-				}
-				require.NoError(t, err, "result should be valid base64")
+			if len(result) > 0 && !bytes.Equal(result, tt.input) {
+				reEncoded := base64.StdEncoding.EncodeToString(result)
+				_, err := base64.StdEncoding.DecodeString(reEncoded)
+				require.NoError(t, err, "re-encoded result should be valid base64")
 			}
 		})
 	}
 }
-
-func TestIsValidBase64(t *testing.T) {
-	t.Parallel()
-
-	tests := []struct {
-		name  string
-		input []byte
-		want  bool
-	}{
-		{
-			name:  "valid base64",
-			input: []byte("SGVsbG8gV29ybGQh"),
-			want:  true,
-		},
-		{
-			name:  "valid base64 with padding",
-			input: []byte("YQ=="),
-			want:  true,
-		},
-		{
-			name:  "raw binary with high bytes",
-			input: []byte{0xFF, 0xD8, 0xFF},
-			want:  false,
-		},
-		{
-			name:  "empty",
-			input: []byte{},
-			want:  true,
-		},
-		{
-			name:  "valid raw base64 without padding",
-			input: []byte("YQ"),
-			want:  true,
-		},
-		{
-			name:  "valid base64 with whitespace",
-			input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
-			want:  true,
-		},
-		{
-			name:  "invalid base64 characters",
-			input: []byte("SGVsbG8!@#$"),
-			want:  false,
-		},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.name, func(t *testing.T) {
-			t.Parallel()
-			got := isValidBase64(tt.input)
-			require.Equal(t, tt.want, got)
-		})
-	}
-}

internal/agent/tools/view.go 🔗

@@ -4,7 +4,6 @@ import (
 	"bufio"
 	"context"
 	_ "embed"
-	"encoding/base64"
 	"fmt"
 	"io"
 	"io/fs"
@@ -189,8 +188,7 @@ func NewViewTool(
 					return fantasy.ToolResponse{}, fmt.Errorf("error reading image file: %w", readErr)
 				}
 
-				encoded := base64.StdEncoding.EncodeToString(imageData)
-				return fantasy.NewImageResponse([]byte(encoded), mimeType), nil
+				return fantasy.NewImageResponse(imageData, mimeType), nil
 			}
 
 			// Read the file content

internal/message/content.go 🔗

@@ -13,6 +13,7 @@ import (
 	"charm.land/fantasy/providers/anthropic"
 	"charm.land/fantasy/providers/google"
 	"charm.land/fantasy/providers/openai"
+	"github.com/charmbracelet/crush/internal/stringext"
 )
 
 type MessageRole string
@@ -24,6 +25,10 @@ const (
 	Tool      MessageRole = "tool"
 )
 
+// mediaLoadFailedPlaceholder is the text substituted for image data that
+// cannot be decoded during session replay.
+const mediaLoadFailedPlaceholder = "[Image data could not be loaded]"
+
 type FinishReason string
 
 const (
@@ -542,9 +547,15 @@ func (m *Message) ToAIMessage() []fantasy.Message {
 					Error: errors.New(result.Content),
 				}
 			} else if result.Data != "" {
-				content = fantasy.ToolResultOutputContentMedia{
-					Data:      result.Data,
-					MediaType: result.MIMEType,
+				if stringext.IsValidBase64(result.Data) {
+					content = fantasy.ToolResultOutputContentMedia{
+						Data:      result.Data,
+						MediaType: result.MIMEType,
+					}
+				} else {
+					content = fantasy.ToolResultOutputContentText{
+						Text: mediaLoadFailedPlaceholder,
+					}
 				}
 			} else {
 				content = fantasy.ToolResultOutputContentText{

internal/message/content_test.go 🔗

@@ -1,9 +1,13 @@
 package message
 
 import (
+	"encoding/base64"
 	"fmt"
 	"strings"
 	"testing"
+
+	"charm.land/fantasy"
+	"github.com/stretchr/testify/require"
 )
 
 func makeTestAttachments(n int, contentSize int) []Attachment {
@@ -19,6 +23,99 @@ func makeTestAttachments(n int, contentSize int) []Attachment {
 	return attachments
 }
 
+func TestToAIMessage_CorruptedMediaData(t *testing.T) {
+	t.Parallel()
+
+	msg := &Message{
+		Role: Tool,
+		Parts: []ContentPart{
+			ToolResult{
+				ToolCallID: "call_123",
+				Name:       "screenshot",
+				Content:    "Loaded image/png content",
+				Data:       "abc\x80def",
+				MIMEType:   "image/png",
+			},
+		},
+	}
+
+	messages := msg.ToAIMessage()
+	require.Len(t, messages, 1)
+	require.Len(t, messages[0].Content, 1)
+
+	part, ok := messages[0].Content[0].(fantasy.ToolResultPart)
+	require.True(t, ok)
+
+	require.Equal(t, "call_123", part.ToolCallID)
+
+	textContent, ok := part.Output.(fantasy.ToolResultOutputContentText)
+	require.True(t, ok, "corrupted media should be downgraded to text")
+	require.Equal(t, mediaLoadFailedPlaceholder, textContent.Text)
+}
+
+func TestToAIMessage_ValidMediaData(t *testing.T) {
+	t.Parallel()
+
+	validBase64 := base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47})
+
+	msg := &Message{
+		Role: Tool,
+		Parts: []ContentPart{
+			ToolResult{
+				ToolCallID: "call_456",
+				Name:       "screenshot",
+				Content:    "Loaded image/png content",
+				Data:       validBase64,
+				MIMEType:   "image/png",
+			},
+		},
+	}
+
+	messages := msg.ToAIMessage()
+	require.Len(t, messages, 1)
+	require.Len(t, messages[0].Content, 1)
+
+	part, ok := messages[0].Content[0].(fantasy.ToolResultPart)
+	require.True(t, ok)
+
+	require.Equal(t, "call_456", part.ToolCallID)
+
+	mediaContent, ok := part.Output.(fantasy.ToolResultOutputContentMedia)
+	require.True(t, ok, "valid media should remain as media")
+	require.Equal(t, validBase64, mediaContent.Data)
+	require.Equal(t, "image/png", mediaContent.MediaType)
+}
+
+func TestToAIMessage_ASCIIButInvalidBase64(t *testing.T) {
+	t.Parallel()
+
+	msg := &Message{
+		Role: Tool,
+		Parts: []ContentPart{
+			ToolResult{
+				ToolCallID: "call_789",
+				Name:       "screenshot",
+				Content:    "Loaded image/png content",
+				Data:       "not-valid-base64!!!",
+				MIMEType:   "image/png",
+			},
+		},
+	}
+
+	messages := msg.ToAIMessage()
+	require.Len(t, messages, 1)
+	require.Len(t, messages[0].Content, 1)
+
+	part, ok := messages[0].Content[0].(fantasy.ToolResultPart)
+	require.True(t, ok)
+
+	require.Equal(t, "call_789", part.ToolCallID)
+
+	textContent, ok := part.Output.(fantasy.ToolResultOutputContentText)
+	require.True(t, ok, "ASCII but invalid base64 should be downgraded to text")
+	require.Equal(t, mediaLoadFailedPlaceholder, textContent.Text)
+}
+
 func BenchmarkPromptWithTextAttachments(b *testing.B) {
 	cases := []struct {
 		name        string

internal/stringext/string.go 🔗

@@ -1,6 +1,7 @@
 package stringext
 
 import (
+	"encoding/base64"
 	"strings"
 
 	"golang.org/x/text/cases"
@@ -20,3 +21,20 @@ func NormalizeSpace(content string) string {
 	content = strings.TrimSpace(content)
 	return content
 }
+
+// IsValidBase64 reports whether s is non-empty, canonical, padded
+// standard base64 (RFC 4648). Besides decoding without error, s must
+// round-trip through decode/encode unchanged, which rejects input that
+// DecodeString tolerates: embedded \r or \n and non-zero trailing bits.
+func IsValidBase64(s string) bool {
+	if s == "" {
+		return false
+	}
+	decoded, err := base64.StdEncoding.DecodeString(s)
+	if err != nil {
+		return false
+	}
+	// Round-trip check rejects what DecodeString silently tolerates:
+	// embedded \r or \n and non-zero trailing padding bits.
+	return base64.StdEncoding.EncodeToString(decoded) == s
+}

internal/stringext/string_test.go 🔗

@@ -0,0 +1,38 @@
+package stringext
+
+import (
+	"encoding/base64"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestIsValidBase64(t *testing.T) {
+	t.Parallel()
+
+	// Real PNG header encoded in standard base64.
+	pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
+	pngBase64 := base64.StdEncoding.EncodeToString(pngHeader)
+
+	tests := []struct {
+		name     string
+		input    string
+		expected bool
+	}{
+		{name: "empty string", input: "", expected: false},
+		{name: "valid no padding", input: "SGVsbG8gV29ybGQh", expected: true},
+		{name: "valid with padding", input: "YQ==", expected: true},
+		{name: "non-ASCII bytes", input: "abc\x80def", expected: false},
+		{name: "ASCII but not base64", input: "hello world!!!", expected: false},
+		{name: "raw encoding no padding", input: "YQ", expected: false},
+		{name: "trailing whitespace", input: "YQ==\n", expected: false},
+		{name: "valid PNG header base64", input: pngBase64, expected: true},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tt.expected, IsValidBase64(tt.input))
+		})
+	}
+}