sourcegraph_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"testing"
  7
  8	"github.com/kujtimiihoxha/termai/internal/permission"
  9	"github.com/stretchr/testify/assert"
 10	"github.com/stretchr/testify/require"
 11)
 12
 13func TestSourcegraphTool_Info(t *testing.T) {
 14	tool := NewSourcegraphTool()
 15	info := tool.Info()
 16
 17	assert.Equal(t, SourcegraphToolName, info.Name)
 18	assert.NotEmpty(t, info.Description)
 19	assert.Contains(t, info.Parameters, "query")
 20	assert.Contains(t, info.Parameters, "count")
 21	assert.Contains(t, info.Parameters, "timeout")
 22	assert.Contains(t, info.Required, "query")
 23}
 24
 25func TestSourcegraphTool_Run(t *testing.T) {
 26	// Setup a mock permission handler that always allows
 27	origPermission := permission.Default
 28	defer func() {
 29		permission.Default = origPermission
 30	}()
 31	permission.Default = newMockPermissionService(true)
 32
 33	t.Run("handles missing query parameter", func(t *testing.T) {
 34		tool := NewSourcegraphTool()
 35		params := SourcegraphParams{
 36			Query: "",
 37		}
 38
 39		paramsJSON, err := json.Marshal(params)
 40		require.NoError(t, err)
 41
 42		call := ToolCall{
 43			Name:  SourcegraphToolName,
 44			Input: string(paramsJSON),
 45		}
 46
 47		response, err := tool.Run(context.Background(), call)
 48		require.NoError(t, err)
 49		assert.Contains(t, response.Content, "Query parameter is required")
 50	})
 51
 52	t.Run("handles invalid parameters", func(t *testing.T) {
 53		tool := NewSourcegraphTool()
 54		call := ToolCall{
 55			Name:  SourcegraphToolName,
 56			Input: "invalid json",
 57		}
 58
 59		response, err := tool.Run(context.Background(), call)
 60		require.NoError(t, err)
 61		assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
 62	})
 63
 64	t.Run("handles permission denied", func(t *testing.T) {
 65		permission.Default = newMockPermissionService(false)
 66
 67		tool := NewSourcegraphTool()
 68		params := SourcegraphParams{
 69			Query: "test query",
 70		}
 71
 72		paramsJSON, err := json.Marshal(params)
 73		require.NoError(t, err)
 74
 75		call := ToolCall{
 76			Name:  SourcegraphToolName,
 77			Input: string(paramsJSON),
 78		}
 79
 80		response, err := tool.Run(context.Background(), call)
 81		require.NoError(t, err)
 82		assert.Contains(t, response.Content, "Permission denied")
 83	})
 84
 85	t.Run("normalizes count parameter", func(t *testing.T) {
 86		// Test cases for count normalization
 87		testCases := []struct {
 88			name          string
 89			inputCount    int
 90			expectedCount int
 91		}{
 92			{"negative count", -5, 10},    // Should use default (10)
 93			{"zero count", 0, 10},         // Should use default (10)
 94			{"valid count", 50, 50},       // Should keep as is
 95			{"excessive count", 150, 100}, // Should cap at 100
 96		}
 97
 98		for _, tc := range testCases {
 99			t.Run(tc.name, func(t *testing.T) {
100				// Verify count normalization logic directly
101				assert.NotPanics(t, func() {
102					// Apply the same normalization logic as in the tool
103					normalizedCount := tc.inputCount
104					if normalizedCount <= 0 {
105						normalizedCount = 10
106					} else if normalizedCount > 100 {
107						normalizedCount = 100
108					}
109
110					assert.Equal(t, tc.expectedCount, normalizedCount)
111				})
112			})
113		}
114	})
115}