@@ -175,7 +175,7 @@ func ensureBase64(data []byte) []byte {
}
normalized := normalizeBase64Input(data)
- if decoded, ok := decodeBase64(normalized); ok {
+ if decoded, ok := decodeLikelyBase64(normalized); ok {
encoded := make([]byte, base64.StdEncoding.EncodedLen(len(decoded)))
base64.StdEncoding.Encode(encoded, decoded)
return encoded
@@ -191,7 +191,11 @@ func normalizeBase64Input(data []byte) []byte {
return []byte(normalized)
}
-func decodeBase64(data []byte) ([]byte, bool) {
+// decodeLikelyBase64 decodes canonical base64 and only accepts unpadded raw
+// base64 when the input also contains a character that plain lowercase text
+// cannot contain in base64. This avoids corrupting raw ASCII payload bytes like
+// "abc" that RawStdEncoding could otherwise decode.
+func decodeLikelyBase64(data []byte) ([]byte, bool) {
if len(data) == 0 {
return data, true
}
@@ -207,6 +211,14 @@ func decodeBase64(data []byte) ([]byte, bool) {
if err == nil {
return decoded, true
}
+
+ if len(s)%4 == 0 {
+ return nil, false
+ }
+ if !strings.ContainsAny(s, "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+/") {
+ return nil, false
+ }
+
decoded, err = base64.RawStdEncoding.DecodeString(s)
if err == nil {
return decoded, true
@@ -216,21 +228,6 @@ func decodeBase64(data []byte) ([]byte, bool) {
// isValidBase64 checks if the data appears to be valid base64-encoded content.
func isValidBase64(data []byte) bool {
- if len(data) == 0 {
- return true
- }
-
- // Base64 strings should only contain ASCII characters.
- for _, b := range data {
- if b > 127 {
- return false
- }
- }
-
- s := string(data)
- if _, err := base64.StdEncoding.DecodeString(s); err == nil {
- return true
- }
- _, err := base64.RawStdEncoding.DecodeString(s)
- return err == nil
+ _, ok := decodeLikelyBase64(normalizeBase64Input(data))
+ return ok
}
@@ -50,6 +50,11 @@ func TestEnsureBase64(t *testing.T) {
input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
},
+ {
+ name: "raw ascii bytes that look like unpadded base64",
+ input: []byte("abc"),
+ wantData: []byte("YWJj"),
+ },
}
for _, tt := range tests {