1package mcp
2
3import (
4 "encoding/base64"
5 "testing"
6
7 "github.com/stretchr/testify/require"
8)
9
10func TestEnsureBase64(t *testing.T) {
11 t.Parallel()
12
13 tests := []struct {
14 name string
15 input []byte
16 wantData []byte // expected output
17 }{
18 {
19 name: "already base64 encoded",
20 input: []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
21 wantData: []byte("SGVsbG8gV29ybGQh"),
22 },
23 {
24 name: "raw binary data (PNG header)",
25 input: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
26 wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
27 },
28 {
29 name: "raw binary with high bytes",
30 input: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
31 wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
32 },
33 {
34 name: "empty data",
35 input: []byte{},
36 wantData: []byte{},
37 },
38 {
39 name: "base64 with padding",
40 input: []byte("YQ=="), // "a" in base64
41 wantData: []byte("YQ=="),
42 },
43 {
44 name: "base64 without padding",
45 input: []byte("YQ"),
46 wantData: []byte("YQ=="),
47 },
48 {
49 name: "base64 with whitespace",
50 input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
51 wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
52 },
53 {
54 name: "raw ascii bytes that look like unpadded base64",
55 input: []byte("abc"),
56 wantData: []byte("YWJj"),
57 },
58 }
59
60 for _, tt := range tests {
61 t.Run(tt.name, func(t *testing.T) {
62 t.Parallel()
63 result := ensureBase64(tt.input)
64 require.Equal(t, tt.wantData, result)
65
66 // Verify the result is valid base64 that can be decoded.
67 if len(result) > 0 {
68 _, err := base64.StdEncoding.DecodeString(string(result))
69 if err != nil {
70 _, err = base64.RawStdEncoding.DecodeString(string(result))
71 }
72 require.NoError(t, err, "result should be valid base64")
73 }
74 })
75 }
76}
77
78func TestIsValidBase64(t *testing.T) {
79 t.Parallel()
80
81 tests := []struct {
82 name string
83 input []byte
84 want bool
85 }{
86 {
87 name: "valid base64",
88 input: []byte("SGVsbG8gV29ybGQh"),
89 want: true,
90 },
91 {
92 name: "valid base64 with padding",
93 input: []byte("YQ=="),
94 want: true,
95 },
96 {
97 name: "raw binary with high bytes",
98 input: []byte{0xFF, 0xD8, 0xFF},
99 want: false,
100 },
101 {
102 name: "empty",
103 input: []byte{},
104 want: true,
105 },
106 {
107 name: "valid raw base64 without padding",
108 input: []byte("YQ"),
109 want: true,
110 },
111 {
112 name: "valid base64 with whitespace",
113 input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
114 want: true,
115 },
116 {
117 name: "invalid base64 characters",
118 input: []byte("SGVsbG8!@#$"),
119 want: false,
120 },
121 }
122
123 for _, tt := range tests {
124 t.Run(tt.name, func(t *testing.T) {
125 t.Parallel()
126 got := isValidBase64(tt.input)
127 require.Equal(t, tt.want, got)
128 })
129 }
130}