fix(mcp): avoid corrupting raw ASCII payloads during base64 normalization

Christian Rocha created

Change summary

internal/agent/tools/mcp/tools.go      | 35 ++++++++++++---------------
internal/agent/tools/mcp/tools_test.go |  5 ++++
2 files changed, 21 insertions(+), 19 deletions(-)

Detailed changes

internal/agent/tools/mcp/tools.go 🔗

@@ -175,7 +175,7 @@ func ensureBase64(data []byte) []byte {
 	}
 
 	normalized := normalizeBase64Input(data)
-	if decoded, ok := decodeBase64(normalized); ok {
+	if decoded, ok := decodeLikelyBase64(normalized); ok {
 		encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded)))
 		base64.StdEncoding.Encode(encoded, decoded)
 		return encoded
@@ -191,7 +191,11 @@ func normalizeBase64Input(data []byte) []byte {
 	return []byte(normalized)
 }
 
-func decodeBase64(data []byte) ([]byte, bool) {
+// decodeLikelyBase64 decodes canonical base64 and only accepts unpadded raw
+// base64 when the input also contains a character that plain lowercase text
+// cannot contain in base64. This avoids corrupting raw ASCII payload bytes like
+// "abc" that RawStdEncoding could otherwise decode.
+func decodeLikelyBase64(data []byte) ([]byte, bool) {
 	if len(data) == 0 {
 		return data, true
 	}
@@ -207,6 +211,14 @@ func decodeBase64(data []byte) ([]byte, bool) {
 	if err == nil {
 		return decoded, true
 	}
+
+	if len(s)%4 == 0 {
+		return nil, false
+	}
+	if !strings.ContainsAny(s, "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/") {
+		return nil, false
+	}
+
 	decoded, err = base64.RawStdEncoding.DecodeString(s)
 	if err == nil {
 		return decoded, true
@@ -216,21 +228,6 @@ func decodeBase64(data []byte) ([]byte, bool) {
 
 // isValidBase64 checks if the data appears to be valid base64-encoded content.
 func isValidBase64(data []byte) bool {
-	if len(data) == 0 {
-		return true
-	}
-
-	// Base64 strings should only contain ASCII characters.
-	for _, b := range data {
-		if b > 127 {
-			return false
-		}
-	}
-
-	s := string(data)
-	if _, err := base64.StdEncoding.DecodeString(s); err == nil {
-		return true
-	}
-	_, err := base64.RawStdEncoding.DecodeString(s)
-	return err == nil
+	_, ok := decodeLikelyBase64(normalizeBase64Input(data))
+	return ok
 }

internal/agent/tools/mcp/tools_test.go 🔗

@@ -50,6 +50,11 @@ func TestEnsureBase64(t *testing.T) {
 			input:    []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
 			wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
 		},
+		{
+			name:     "raw ascii bytes that look like unpadded base64",
+			input:    []byte("abc"),
+			wantData: []byte("YWJj"),
+		},
 	}
 
 	for _, tt := range tests {