1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "testing"
7
8 "github.com/kujtimiihoxha/termai/internal/permission"
9 "github.com/stretchr/testify/assert"
10 "github.com/stretchr/testify/require"
11)
12
13func TestSourcegraphTool_Info(t *testing.T) {
14 tool := NewSourcegraphTool()
15 info := tool.Info()
16
17 assert.Equal(t, SourcegraphToolName, info.Name)
18 assert.NotEmpty(t, info.Description)
19 assert.Contains(t, info.Parameters, "query")
20 assert.Contains(t, info.Parameters, "count")
21 assert.Contains(t, info.Parameters, "timeout")
22 assert.Contains(t, info.Required, "query")
23}
24
25func TestSourcegraphTool_Run(t *testing.T) {
26 // Setup a mock permission handler that always allows
27 origPermission := permission.Default
28 defer func() {
29 permission.Default = origPermission
30 }()
31 permission.Default = newMockPermissionService(true)
32
33 t.Run("handles missing query parameter", func(t *testing.T) {
34 tool := NewSourcegraphTool()
35 params := SourcegraphParams{
36 Query: "",
37 }
38
39 paramsJSON, err := json.Marshal(params)
40 require.NoError(t, err)
41
42 call := ToolCall{
43 Name: SourcegraphToolName,
44 Input: string(paramsJSON),
45 }
46
47 response, err := tool.Run(context.Background(), call)
48 require.NoError(t, err)
49 assert.Contains(t, response.Content, "Query parameter is required")
50 })
51
52 t.Run("handles invalid parameters", func(t *testing.T) {
53 tool := NewSourcegraphTool()
54 call := ToolCall{
55 Name: SourcegraphToolName,
56 Input: "invalid json",
57 }
58
59 response, err := tool.Run(context.Background(), call)
60 require.NoError(t, err)
61 assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
62 })
63
64 t.Run("handles permission denied", func(t *testing.T) {
65 permission.Default = newMockPermissionService(false)
66
67 tool := NewSourcegraphTool()
68 params := SourcegraphParams{
69 Query: "test query",
70 }
71
72 paramsJSON, err := json.Marshal(params)
73 require.NoError(t, err)
74
75 call := ToolCall{
76 Name: SourcegraphToolName,
77 Input: string(paramsJSON),
78 }
79
80 response, err := tool.Run(context.Background(), call)
81 require.NoError(t, err)
82 assert.Contains(t, response.Content, "Permission denied")
83 })
84
85 t.Run("normalizes count parameter", func(t *testing.T) {
86 // Test cases for count normalization
87 testCases := []struct {
88 name string
89 inputCount int
90 expectedCount int
91 }{
92 {"negative count", -5, 10}, // Should use default (10)
93 {"zero count", 0, 10}, // Should use default (10)
94 {"valid count", 50, 50}, // Should keep as is
95 {"excessive count", 150, 100}, // Should cap at 100
96 }
97
98 for _, tc := range testCases {
99 t.Run(tc.name, func(t *testing.T) {
100 // Verify count normalization logic directly
101 assert.NotPanics(t, func() {
102 // Apply the same normalization logic as in the tool
103 normalizedCount := tc.inputCount
104 if normalizedCount <= 0 {
105 normalizedCount = 10
106 } else if normalizedCount > 100 {
107 normalizedCount = 100
108 }
109
110 assert.Equal(t, tc.expectedCount, normalizedCount)
111 })
112 })
113 }
114 })
115}