1package mcp
2
3import (
4 "encoding/base64"
5 "testing"
6
7 "github.com/stretchr/testify/require"
8)
9
10func TestEnsureBase64(t *testing.T) {
11 t.Parallel()
12
13 tests := []struct {
14 name string
15 input []byte
16 wantData []byte // expected output
17 }{
18 {
19 name: "already base64 encoded",
20 input: []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
21 wantData: []byte("SGVsbG8gV29ybGQh"),
22 },
23 {
24 name: "raw binary data (PNG header)",
25 input: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
26 wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
27 },
28 {
29 name: "raw binary with high bytes",
30 input: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
31 wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
32 },
33 {
34 name: "empty data",
35 input: []byte{},
36 wantData: []byte{},
37 },
38 {
39 name: "base64 with padding",
40 input: []byte("YQ=="), // "a" in base64
41 wantData: []byte("YQ=="),
42 },
43 {
44 name: "base64 without padding",
45 input: []byte("YQ"),
46 wantData: []byte("YQ=="),
47 },
48 {
49 name: "base64 with whitespace",
50 input: []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
51 wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
52 },
53 }
54
55 for _, tt := range tests {
56 t.Run(tt.name, func(t *testing.T) {
57 t.Parallel()
58 result := ensureBase64(tt.input)
59 require.Equal(t, tt.wantData, result)
60
61 // Verify the result is valid base64 that can be decoded.
62 if len(result) > 0 {
63 _, err := base64.StdEncoding.DecodeString(string(result))
64 if err != nil {
65 _, err = base64.RawStdEncoding.DecodeString(string(result))
66 }
67 require.NoError(t, err, "result should be valid base64")
68 }
69 })
70 }
71}
72
73func TestIsValidBase64(t *testing.T) {
74 t.Parallel()
75
76 tests := []struct {
77 name string
78 input []byte
79 want bool
80 }{
81 {
82 name: "valid base64",
83 input: []byte("SGVsbG8gV29ybGQh"),
84 want: true,
85 },
86 {
87 name: "valid base64 with padding",
88 input: []byte("YQ=="),
89 want: true,
90 },
91 {
92 name: "raw binary with high bytes",
93 input: []byte{0xFF, 0xD8, 0xFF},
94 want: false,
95 },
96 {
97 name: "empty",
98 input: []byte{},
99 want: true,
100 },
101 {
102 name: "valid raw base64 without padding",
103 input: []byte("YQ"),
104 want: true,
105 },
106 {
107 name: "valid base64 with whitespace",
108 input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
109 want: true,
110 },
111 {
112 name: "invalid base64 characters",
113 input: []byte("SGVsbG8!@#$"),
114 want: false,
115 },
116 }
117
118 for _, tt := range tests {
119 t.Run(tt.name, func(t *testing.T) {
120 t.Parallel()
121 got := isValidBase64(tt.input)
122 require.Equal(t, tt.want, got)
123 })
124 }
125}