tools_test.go

 1package mcp
 2
 3import (
 4	"bytes"
 5	"encoding/base64"
 6	"testing"
 7
 8	"github.com/stretchr/testify/require"
 9)
10
11func TestEnsureRawBytes(t *testing.T) {
12	t.Parallel()
13
14	tests := []struct {
15		name     string
16		input    []byte
17		wantData []byte
18	}{
19		{
20			name:     "already base64 encoded",
21			input:    []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
22			wantData: []byte("Hello World!"),
23		},
24		{
25			name:     "raw binary data (PNG header)",
26			input:    []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
27			wantData: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
28		},
29		{
30			name:     "raw binary with high bytes",
31			input:    []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
32			wantData: []byte{0xFF, 0xD8, 0xFF, 0xE0},
33		},
34		{
35			name:     "empty data",
36			input:    []byte{},
37			wantData: []byte{},
38		},
39		{
40			name:     "base64 with padding",
41			input:    []byte("YQ=="), // "a" in base64
42			wantData: []byte("a"),
43		},
44		{
45			name:     "base64 without padding",
46			input:    []byte("YQ"),
47			wantData: []byte("a"),
48		},
49		{
50			name:     "base64 with whitespace",
51			input:    []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
52			wantData: []byte("SGVsbG8gV29ybGQh"),
53		},
54	}
55
56	for _, tt := range tests {
57		t.Run(tt.name, func(t *testing.T) {
58			t.Parallel()
59			result := ensureRawBytes(tt.input)
60			require.Equal(t, tt.wantData, result)
61
62			if len(result) > 0 && !bytes.Equal(result, tt.input) {
63				reEncoded := base64.StdEncoding.EncodeToString(result)
64				_, err := base64.StdEncoding.DecodeString(reEncoded)
65				require.NoError(t, err, "re-encoded result should be valid base64")
66			}
67		})
68	}
69}