1package mcp
2
3import (
4 "encoding/base64"
5 "testing"
6
7 "github.com/stretchr/testify/require"
8)
9
10func TestEnsureBase64(t *testing.T) {
11 t.Parallel()
12
13 tests := []struct {
14 name string
15 input []byte
16 wantData []byte // expected output
17 }{
18 {
19 name: "already base64 encoded",
20 input: []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
21 wantData: []byte("SGVsbG8gV29ybGQh"),
22 },
23 {
24 name: "raw binary data (PNG header)",
25 input: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
26 wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
27 },
28 {
29 name: "raw binary with high bytes",
30 input: []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
31 wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
32 },
33 {
34 name: "empty data",
35 input: []byte{},
36 wantData: []byte{},
37 },
38 {
39 name: "base64 with padding",
40 input: []byte("YQ=="), // "a" in base64
41 wantData: []byte("YQ=="),
42 },
43 }
44
45 for _, tt := range tests {
46 t.Run(tt.name, func(t *testing.T) {
47 t.Parallel()
48 result := ensureBase64(tt.input)
49 require.Equal(t, tt.wantData, result)
50
51 // Verify the result is valid base64 that can be decoded.
52 if len(result) > 0 {
53 _, err := base64.StdEncoding.DecodeString(string(result))
54 require.NoError(t, err, "result should be valid base64")
55 }
56 })
57 }
58}
59
60func TestIsValidBase64(t *testing.T) {
61 t.Parallel()
62
63 tests := []struct {
64 name string
65 input []byte
66 want bool
67 }{
68 {
69 name: "valid base64",
70 input: []byte("SGVsbG8gV29ybGQh"),
71 want: true,
72 },
73 {
74 name: "valid base64 with padding",
75 input: []byte("YQ=="),
76 want: true,
77 },
78 {
79 name: "raw binary with high bytes",
80 input: []byte{0xFF, 0xD8, 0xFF},
81 want: false,
82 },
83 {
84 name: "empty",
85 input: []byte{},
86 want: true,
87 },
88 {
89 name: "invalid base64 characters",
90 input: []byte("SGVsbG8!@#$"),
91 want: false,
92 },
93 }
94
95 for _, tt := range tests {
96 t.Run(tt.name, func(t *testing.T) {
97 t.Parallel()
98 got := isValidBase64(tt.input)
99 require.Equal(t, tt.want, got)
100 })
101 }
102}