tools_test.go

  1package mcp
  2
  3import (
  4	"encoding/base64"
  5	"testing"
  6
  7	"github.com/stretchr/testify/require"
  8)
  9
 10func TestEnsureBase64(t *testing.T) {
 11	t.Parallel()
 12
 13	tests := []struct {
 14		name     string
 15		input    []byte
 16		wantData []byte // expected output
 17	}{
 18		{
 19			name:     "already base64 encoded",
 20			input:    []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
 21			wantData: []byte("SGVsbG8gV29ybGQh"),
 22		},
 23		{
 24			name:     "raw binary data (PNG header)",
 25			input:    []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
 26			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A})),
 27		},
 28		{
 29			name:     "raw binary with high bytes",
 30			input:    []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
 31			wantData: []byte(base64.StdEncoding.EncodeToString([]byte{0xFF, 0xD8, 0xFF, 0xE0})),
 32		},
 33		{
 34			name:     "empty data",
 35			input:    []byte{},
 36			wantData: []byte{},
 37		},
 38		{
 39			name:     "base64 with padding",
 40			input:    []byte("YQ=="), // "a" in base64
 41			wantData: []byte("YQ=="),
 42		},
 43	}
 44
 45	for _, tt := range tests {
 46		t.Run(tt.name, func(t *testing.T) {
 47			t.Parallel()
 48			result := ensureBase64(tt.input)
 49			require.Equal(t, tt.wantData, result)
 50
 51			// Verify the result is valid base64 that can be decoded.
 52			if len(result) > 0 {
 53				_, err := base64.StdEncoding.DecodeString(string(result))
 54				require.NoError(t, err, "result should be valid base64")
 55			}
 56		})
 57	}
 58}
 59
 60func TestIsValidBase64(t *testing.T) {
 61	t.Parallel()
 62
 63	tests := []struct {
 64		name  string
 65		input []byte
 66		want  bool
 67	}{
 68		{
 69			name:  "valid base64",
 70			input: []byte("SGVsbG8gV29ybGQh"),
 71			want:  true,
 72		},
 73		{
 74			name:  "valid base64 with padding",
 75			input: []byte("YQ=="),
 76			want:  true,
 77		},
 78		{
 79			name:  "raw binary with high bytes",
 80			input: []byte{0xFF, 0xD8, 0xFF},
 81			want:  false,
 82		},
 83		{
 84			name:  "empty",
 85			input: []byte{},
 86			want:  true,
 87		},
 88		{
 89			name:  "invalid base64 characters",
 90			input: []byte("SGVsbG8!@#$"),
 91			want:  false,
 92		},
 93	}
 94
 95	for _, tt := range tests {
 96		t.Run(tt.name, func(t *testing.T) {
 97			t.Parallel()
 98			got := isValidBase64(tt.input)
 99			require.Equal(t, tt.want, got)
100		})
101	}
102}