tools_test.go

  1package mcp
  2
  3import (
  4	"bytes"
  5	"encoding/base64"
  6	"testing"
  7
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12func TestEnsureRawBytes(t *testing.T) {
 13	t.Parallel()
 14
 15	tests := []struct {
 16		name     string
 17		input    []byte
 18		wantData []byte
 19	}{
 20		{
 21			name:     "already base64 encoded",
 22			input:    []byte("SGVsbG8gV29ybGQh"), // "Hello World!" in base64
 23			wantData: []byte("Hello World!"),
 24		},
 25		{
 26			name:     "raw binary data (PNG header)",
 27			input:    []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
 28			wantData: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A},
 29		},
 30		{
 31			name:     "raw binary with high bytes",
 32			input:    []byte{0xFF, 0xD8, 0xFF, 0xE0}, // JPEG header
 33			wantData: []byte{0xFF, 0xD8, 0xFF, 0xE0},
 34		},
 35		{
 36			name:     "empty data",
 37			input:    []byte{},
 38			wantData: []byte{},
 39		},
 40		{
 41			name:     "base64 with padding",
 42			input:    []byte("YQ=="), // "a" in base64
 43			wantData: []byte("a"),
 44		},
 45		{
 46			name:     "base64 without padding",
 47			input:    []byte("YQ"),
 48			wantData: []byte("a"),
 49		},
 50		{
 51			name:     "base64 with whitespace",
 52			input:    []byte("U0dWc2JHOGdWMjl5YkdRaA==\n"),
 53			wantData: []byte("SGVsbG8gV29ybGQh"),
 54		},
 55	}
 56
 57	for _, tt := range tests {
 58		t.Run(tt.name, func(t *testing.T) {
 59			t.Parallel()
 60			result := ensureRawBytes(tt.input)
 61			require.Equal(t, tt.wantData, result)
 62
 63			if len(result) > 0 && !bytes.Equal(result, tt.input) {
 64				reEncoded := base64.StdEncoding.EncodeToString(result)
 65				_, err := base64.StdEncoding.DecodeString(reEncoded)
 66				require.NoError(t, err, "re-encoded result should be valid base64")
 67			}
 68		})
 69	}
 70}
 71
 72func TestFilterTools(t *testing.T) {
 73	t.Parallel()
 74
 75	tools := []*Tool{
 76		{Name: "tool_a"},
 77		{Name: "tool_b"},
 78		{Name: "tool_c"},
 79	}
 80
 81	t.Run("no filters returns all tools", func(t *testing.T) {
 82		t.Parallel()
 83		result := filterTools(config.MCPConfig{}, tools)
 84		require.Len(t, result, 3)
 85	})
 86
 87	t.Run("disabled tools filters deny list", func(t *testing.T) {
 88		t.Parallel()
 89		result := filterTools(config.MCPConfig{DisabledTools: []string{"tool_a"}}, tools)
 90		require.Len(t, result, 2)
 91		require.Equal(t, "tool_b", result[0].Name)
 92		require.Equal(t, "tool_c", result[1].Name)
 93	})
 94
 95	t.Run("enabled tools acts as allow list", func(t *testing.T) {
 96		t.Parallel()
 97		result := filterTools(config.MCPConfig{EnabledTools: []string{"tool_b"}}, tools)
 98		require.Len(t, result, 1)
 99		require.Equal(t, "tool_b", result[0].Name)
100	})
101
102	t.Run("enabled and disabled both apply", func(t *testing.T) {
103		t.Parallel()
104		result := filterTools(config.MCPConfig{
105			EnabledTools:  []string{"tool_a", "tool_b"},
106			DisabledTools: []string{"tool_b"},
107		}, tools)
108		require.Len(t, result, 1)
109		require.Equal(t, "tool_a", result[0].Name)
110	})
111
112	t.Run("enabled with non-existent tool returns empty", func(t *testing.T) {
113		t.Parallel()
114		result := filterTools(config.MCPConfig{EnabledTools: []string{"non_existent"}}, tools)
115		require.Len(t, result, 0)
116	})
117}