tools_test.go

  1package mcp
  2
  3import (
  4	"encoding/base64"
  5	"testing"
  6
  7	"github.com/stretchr/testify/require"
  8)
  9
 10func TestEnsureBase64(t *testing.T) {
 11	t.Parallel()
 12
 13	tests := []struct {
 14		name     string
 15		input    []byte
 16		wantData []byte // expected output
 17	}{
 18		{
 19			name:     "already base64 encoded",
 20			input:    []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
 21			wantData: []byte("SGVsbG8gV29ybGQh"),
 22		},
 23		{
 24			name:     "raw binary data (PNG header)",
 25			input:    []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
 26			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
 27		},
 28		{
 29			name:     "raw binary with high bytes",
 30			input:    []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
 31			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
 32		},
 33		{
 34			name:     "empty data",
 35			input:    []byte{},
 36			wantData: []byte{},
 37		},
 38		{
 39			name:     "base64 with padding",
 40			input:    []byte("YQ=="), // "a" in base64
 41			wantData: []byte("YQ=="),
 42		},
 43		{
 44			name:     "base64 without padding",
 45			input:    []byte("YQ"),
 46			wantData: []byte("YQ=="),
 47		},
 48		{
 49			name:     "base64 with whitespace",
 50			input:    []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
 51			wantData: []byte("U0dWc2JHOGdWMjl5YkdRaA=="),
 52		},
 53	}
 54
 55	for _, tt := range tests {
 56		t.Run(tt.name, func(t *testing.T) {
 57			t.Parallel()
 58			result := ensureBase64(tt.input)
 59			require.Equal(t, tt.wantData, result)
 60
 61			// Verify the result is valid base64 that can be decoded.
 62			if len(result) > 0 {
 63				_, err := base64.StdEncoding.DecodeString(string(result))
 64				if err != nil {
 65					_, err = base64.RawStdEncoding.DecodeString(string(result))
 66				}
 67				require.NoError(t, err, "result should be valid base64")
 68			}
 69		})
 70	}
 71}
 72
 73func TestIsValidBase64(t *testing.T) {
 74	t.Parallel()
 75
 76	tests := []struct {
 77		name  string
 78		input []byte
 79		want  bool
 80	}{
 81		{
 82			name:  "valid base64",
 83			input: []byte("SGVsbG8gV29ybGQh"),
 84			want:  true,
 85		},
 86		{
 87			name:  "valid base64 with padding",
 88			input: []byte("YQ=="),
 89			want:  true,
 90		},
 91		{
 92			name:  "raw binary with high bytes",
 93			input: []byte{0xFF, 0xD8, 0xFF},
 94			want:  false,
 95		},
 96		{
 97			name:  "empty",
 98			input: []byte{},
 99			want:  true,
100		},
101		{
102			name:  "valid raw base64 without padding",
103			input: []byte("YQ"),
104			want:  true,
105		},
106		{
107			name:  "valid base64 with whitespace",
108			input: normalizeBase64Input([]byte("U0dWc2JHOGdWMjl5YkdRaA==\n")),
109			want:  true,
110		},
111		{
112			name:  "invalid base64 characters",
113			input: []byte("SGVsbG8!@#$"),
114			want:  false,
115		},
116	}
117
118	for _, tt := range tests {
119		t.Run(tt.name, func(t *testing.T) {
120			t.Parallel()
121			got := isValidBase64(tt.input)
122			require.Equal(t, tt.want, got)
123		})
124	}
125}