From df094bdeca6717cb36c127fddda69a041b4e38f1 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Thu, 19 Mar 2026 11:09:54 -0400 Subject: [PATCH] fix(mcp): handle raw/whitespace base64 --- internal/agent/tools/mcp/tools.go | 54 ++++++++++++++++++++++---- internal/agent/tools/mcp/tools_test.go | 23 +++++++++++ 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index ce85e591e55139343e43179bdd33c88b49c274be..05d6b2b75d8fadff2e9af8385817ac135722f1a8 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -167,32 +167,70 @@ func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool) return filtered } -// ensureBase64 checks if data is valid base64 and returns it as-is if so, -// otherwise encodes the raw binary data to base64. +// ensureBase64 normalizes valid base64 input and guarantees padded +// base64.StdEncoding output; otherwise it encodes raw binary data. func ensureBase64(data []byte) []byte { - // Check if the data is already valid base64 by attempting to decode it. - // Valid base64 should only contain ASCII characters (A-Z, a-z, 0-9, +, /, =). - if isValidBase64(data) { + if len(data) == 0 { return data } - // Data is raw binary, encode it to base64. + + normalized := normalizeBase64Input(data) + if decoded, ok := decodeBase64(normalized); ok { + encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded))) + base64.StdEncoding.Encode(encoded, decoded) + return encoded + } + encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data))) base64.StdEncoding.Encode(encoded, data) return encoded } +func normalizeBase64Input(data []byte) []byte { + normalized := strings.Join(strings.Fields(string(data)), "") + return []byte(normalized) +} + +func decodeBase64(data []byte) ([]byte, bool) { + if len(data) == 0 { + return data, true + } + + for _, b := range data { + if b > 127 { + return nil, false + } + } + + s := string(data) + decoded, err := base64.StdEncoding.DecodeString(s) + if err == nil { + return decoded, true + } + decoded, err = base64.RawStdEncoding.DecodeString(s) + if err == nil { + return decoded, true + } + return nil, false +} + // isValidBase64 checks if the data appears to be valid base64-encoded content. func isValidBase64(data []byte) bool { if len(data) == 0 { return true } + // Base64 strings should only contain ASCII characters. for _, b := range data { if b > 127 { return false } } - // Try to decode to verify it's valid base64. - _, err := base64.StdEncoding.DecodeString(string(data)) + + s := string(data) + if _, err := base64.StdEncoding.DecodeString(s); err == nil { + return true + } + _, err := base64.RawStdEncoding.DecodeString(s) return err == nil } diff --git a/internal/agent/tools/mcp/tools_test.go b/internal/agent/tools/mcp/tools_test.go index 3795381ebd2a6a54540a6905ef39c97b8ce59575..aae4428ed6b830549540611761c22f070eeda925 100644 --- a/internal/agent/tools/mcp/tools_test.go +++ b/internal/agent/tools/mcp/tools_test.go @@ -40,6 +40,16 @@ func TestEnsureBase64(t *testing.T) { input: []byte("YQ=="), // "a" in base64 wantData: []byte("YQ=="), }, + { + name: "base64 without padding", + input: []byte("YQ"), + wantData: []byte("YQ=="), + }, + { + name: "base64 with whitespace", + input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"), + wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="), + }, } for _, tt := range tests { @@ -51,6 +61,9 @@ func TestEnsureBase64(t *testing.T) { // Verify the result is valid base64 that can be decoded. if len(result) > 0 { _, err := base64.StdEncoding.DecodeString(string(result)) + if err != nil { + _, err = base64.RawStdEncoding.DecodeString(string(result)) + } require.NoError(t, err, "result should be valid base64") } }) @@ -85,6 +98,16 @@ func TestIsValidBase64(t *testing.T) { input: []byte{}, want: true, }, + { + name: "valid raw base64 without padding", + input: []byte("YQ"), + want: true, + }, + { + name: "valid base64 with whitespace", + input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")), + want: true, + }, { name: "invalid base64 characters", input: []byte("SGVsbG8!@#$"),