diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index 05d6b2b75d8fadff2e9af8385817ac135722f1a8..cd30e49e0f92f6fc9c629b9e78dbf7c9540704e9 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -175,7 +175,7 @@ func ensureBase64(data []byte) []byte { } normalized := normalizeBase64Input(data) - if decoded, ok := decodeBase64(normalized); ok { + if decoded, ok := decodeLikelyBase64(normalized); ok { encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded))) base64.StdEncoding.Encode(encoded, decoded) return encoded @@ -191,7 +191,11 @@ func normalizeBase64Input(data []byte) []byte { return []byte(normalized) } -func decodeBase64(data []byte) ([]byte, bool) { +// decodeLikelyBase64 decodes canonical base64 and only accepts unpadded raw +// base64 when the input also contains a character that plain lowercase text +// cannot contain in base64. This avoids corrupting raw ASCII payload bytes like +// "abc" that RawStdEncoding could otherwise decode. +func decodeLikelyBase64(data []byte) ([]byte, bool) { if len(data) == 0 { return data, true } @@ -207,6 +211,14 @@ func decodeBase64(data []byte) ([]byte, bool) { if err == nil { return decoded, true } + + if len(s)%4 == 0 { + return nil, false + } + if !strings.ContainsAny(s, "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/") { + return nil, false + } + decoded, err = base64.RawStdEncoding.DecodeString(s) if err == nil { return decoded, true @@ -216,21 +228,6 @@ func decodeBase64(data []byte) ([]byte, bool) { // isValidBase64 checks if the data appears to be valid base64-encoded content. func isValidBase64(data []byte) bool { - if len(data) == 0 { - return true - } - - // Base64 strings should only contain ASCII characters. - for _, b := range data { - if b > 127 { - return false - } - } - - s := string(data) - if _, err := base64.StdEncoding.DecodeString(s); err == nil { - return true - } - _, err := base64.RawStdEncoding.DecodeString(s) - return err == nil + _, ok := decodeLikelyBase64(normalizeBase64Input(data)) + return ok } diff --git a/internal/agent/tools/mcp/tools_test.go b/internal/agent/tools/mcp/tools_test.go index aae4428ed6b830549540611761c22f070eeda925..753875a16dce2b4cf838de924ad4ad2e91c3d9a4 100644 --- a/internal/agent/tools/mcp/tools_test.go +++ b/internal/agent/tools/mcp/tools_test.go @@ -50,6 +50,11 @@ func TestEnsureBase64(t *testing.T) { input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"), wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="), }, + { + name: "raw ascii bytes that look like unpadded base64", + input: []byte("abc"), + wantData: []byte("YWJj"), + }, } for _, tt := range tests {