tools_test.go

  1package mcp
  2
  3import (
  4	"encoding/base64"
  5	"testing"
  6
  7	"github.com/stretchr/testify/require"
  8)
  9
 10func TestEnsureBase64(t *testing.T) {
 11	t.Parallel()
 12
 13	tests := []struct {
 14		name     string
 15		input    []byte
 16		wantData []byte // expected output
 17	}{
 18		{
 19			name:     "already base64 encoded",
 20			input:    []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
 21			wantData: []byte("SGVsbG8gV29ybGQh"),
 22		},
 23		{
 24			name:     "raw binary data (PNG header)",
 25			input:    []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
 26			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
 27		},
 28		{
 29			name:     "raw binary with high bytes",
 30			input:    []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
 31			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
 32		},
 33		{
 34			name:     "empty data",
 35			input:    []byte{},
 36			wantData: []byte{},
 37		},
 38		{
 39			name:     "base64 with padding",
 40			input:    []byte("YQ=="), // "a" in base64
 41			wantData: []byte("YQ=="),
 42		},
 43		{
 44			name:     "base64 without padding",
 45			input:    []byte("YQ"),
 46			wantData: []byte("YQ=="),
 47		},
 48		{
 49			name:     "base64 with whitespace",
 50			input:    []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
 51			wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
 52		},
 53		{
 54			name:     "raw ascii bytes that look like unpadded base64",
 55			input:    []byte("abc"),
 56			wantData: []byte("YWJj"),
 57		},
 58	}
 59
 60	for _, tt := range tests {
 61		t.Run(tt.name, func(t *testing.T) {
 62			t.Parallel()
 63			result := ensureBase64(tt.input)
 64			require.Equal(t, tt.wantData, result)
 65
 66			// Verify the result is valid base64 that can be decoded.
 67			if len(result) > 0 {
 68				_, err := base64.StdEncoding.DecodeString(string(result))
 69				if err != nil {
 70					_, err = base64.RawStdEncoding.DecodeString(string(result))
 71				}
 72				require.NoError(t, err, "result should be valid base64")
 73			}
 74		})
 75	}
 76}
 77
 78func TestIsValidBase64(t *testing.T) {
 79	t.Parallel()
 80
 81	tests := []struct {
 82		name  string
 83		input []byte
 84		want  bool
 85	}{
 86		{
 87			name:  "valid base64",
 88			input: []byte("SGVsbG8gV29ybGQh"),
 89			want:  true,
 90		},
 91		{
 92			name:  "valid base64 with padding",
 93			input: []byte("YQ=="),
 94			want:  true,
 95		},
 96		{
 97			name:  "raw binary with high bytes",
 98			input: []byte{0xFF, 0xD8, 0xFF},
 99			want:  false,
100		},
101		{
102			name:  "empty",
103			input: []byte{},
104			want:  true,
105		},
106		{
107			name:  "valid raw base64 without padding",
108			input: []byte("YQ"),
109			want:  true,
110		},
111		{
112			name:  "valid base64 with whitespace",
113			input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
114			want:  true,
115		},
116		{
117			name:  "invalid base64 characters",
118			input: []byte("SGVsbG8!@#$"),
119			want:  false,
120		},
121	}
122
123	for _, tt := range tests {
124		t.Run(tt.name, func(t *testing.T) {
125			t.Parallel()
126			got := isValidBase64(tt.input)
127			require.Equal(t, tt.want, got)
128		})
129	}
130}