1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "testing"
7
8 "github.com/stretchr/testify/assert"
9 "github.com/stretchr/testify/require"
10)
11
12func TestSourcegraphTool_Info(t *testing.T) {
13 tool := NewSourcegraphTool()
14 info := tool.Info()
15
16 assert.Equal(t, SourcegraphToolName, info.Name)
17 assert.NotEmpty(t, info.Description)
18 assert.Contains(t, info.Parameters, "query")
19 assert.Contains(t, info.Parameters, "count")
20 assert.Contains(t, info.Parameters, "timeout")
21 assert.Contains(t, info.Required, "query")
22}
23
24func TestSourcegraphTool_Run(t *testing.T) {
25 t.Run("handles missing query parameter", func(t *testing.T) {
26 tool := NewSourcegraphTool()
27 params := SourcegraphParams{
28 Query: "",
29 }
30
31 paramsJSON, err := json.Marshal(params)
32 require.NoError(t, err)
33
34 call := ToolCall{
35 Name: SourcegraphToolName,
36 Input: string(paramsJSON),
37 }
38
39 response, err := tool.Run(context.Background(), call)
40 require.NoError(t, err)
41 assert.Contains(t, response.Content, "Query parameter is required")
42 })
43
44 t.Run("handles invalid parameters", func(t *testing.T) {
45 tool := NewSourcegraphTool()
46 call := ToolCall{
47 Name: SourcegraphToolName,
48 Input: "invalid json",
49 }
50
51 response, err := tool.Run(context.Background(), call)
52 require.NoError(t, err)
53 assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
54 })
55
56 t.Run("normalizes count parameter", func(t *testing.T) {
57 // Test cases for count normalization
58 testCases := []struct {
59 name string
60 inputCount int
61 expectedCount int
62 }{
63 {"negative count", -5, 10}, // Should use default (10)
64 {"zero count", 0, 10}, // Should use default (10)
65 {"valid count", 50, 50}, // Should keep as is
66 {"excessive count", 150, 100}, // Should cap at 100
67 }
68
69 for _, tc := range testCases {
70 t.Run(tc.name, func(t *testing.T) {
71 // Verify count normalization logic directly
72 assert.NotPanics(t, func() {
73 // Apply the same normalization logic as in the tool
74 normalizedCount := tc.inputCount
75 if normalizedCount <= 0 {
76 normalizedCount = 10
77 } else if normalizedCount > 100 {
78 normalizedCount = 100
79 }
80
81 assert.Equal(t, tc.expectedCount, normalizedCount)
82 })
83 })
84 }
85 })
86}