@@ -167,32 +167,70 @@ func filterDisabledTools(cfg *config.ConfigStore, mcpName string, tools []*Tool)
return filtered
}
-// ensureBase64 checks if data is valid base64 and returns it as-is if so,
-// otherwise encodes the raw binary data to base64.
+// ensureBase64 normalizes valid base64 input and guarantees padded
+// base64.StdEncoding output; otherwise it encodes raw binary data.
func ensureBase64(data []byte) []byte {
- // Check if the data is already valid base64 by attempting to decode it.
- // Valid base64 should only contain ASCII characters (A-Z, a-z, 0-9, +, /, =).
- if isValidBase64(data) {
+ if len(data) == 0 {
return data
}
- // Data is raw binary, encode it to base64.
+
+ normalized := normalizeBase64Input(data)
+ if decoded, ok := decodeBase64(normalized); ok {
+ encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded)))
+ base64.StdEncoding.Encode(encoded, decoded)
+ return encoded
+ }
+
encoded := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
base64.StdEncoding.Encode(encoded, data)
return encoded
}
+func normalizeBase64Input(data []byte) []byte {
+ normalized := strings.Join(strings.Fields(string(data)), "")
+ return []byte(normalized)
+}
+
+func decodeBase64(data []byte) ([]byte, bool) {
+ if len(data) == 0 {
+ return data, true
+ }
+
+ for _, b := range data {
+ if b > 127 {
+ return nil, false
+ }
+ }
+
+ s := string(data)
+ decoded, err := base64.StdEncoding.DecodeString(s)
+ if err == nil {
+ return decoded, true
+ }
+ decoded, err = base64.RawStdEncoding.DecodeString(s)
+ if err == nil {
+ return decoded, true
+ }
+ return nil, false
+}
+
// isValidBase64 checks if the data appears to be valid base64-encoded content.
func isValidBase64(data []byte) bool {
if len(data) == 0 {
return true
}
+
// Base64 strings should only contain ASCII characters.
for _, b := range data {
if b > 127 {
return false
}
}
- // Try to decode to verify it's valid base64.
- _, err := base64.StdEncoding.DecodeString(string(data))
+
+ s := string(data)
+ if _, err := base64.StdEncoding.DecodeString(s); err == nil {
+ return true
+ }
+ _, err := base64.RawStdEncoding.DecodeString(s)
return err == nil
}
@@ -40,6 +40,16 @@ func TestEnsureBase64(t *testing.T) {
input: []byte("YQ=="), // "a" in base64
wantData: []byte("YQ=="),
},
+ {
+ name: "base64 without padding",
+ input: []byte("YQ"),
+ wantData: []byte("YQ=="),
+ },
+ {
+ name: "base64 with whitespace",
+ input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
+ wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
+ },
}
for _, tt := range tests {
@@ -51,6 +61,9 @@ func TestEnsureBase64(t *testing.T) {
// Verify the result is valid base64 that can be decoded.
if len(result) > 0 {
_, err := base64.StdEncoding.DecodeString(string(result))
+ if err != nil {
+ _, err = base64.RawStdEncoding.DecodeString(string(result))
+ }
require.NoError(t, err, "result should be valid base64")
}
})
@@ -85,6 +98,16 @@ func TestIsValidBase64(t *testing.T) {
input: []byte{},
want: true,
},
+ {
+ name: "valid raw base64 without padding",
+ input: []byte("YQ"),
+ want: true,
+ },
+ {
+ name: "valid base64 with whitespace",
+ input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
+ want: true,
+ },
{
name: "invalid base64 characters",
input: []byte("SGVsbG8!@#$"),