sourcegraph_test.go

 1package tools
 2
 3import (
 4	"context"
 5	"encoding/json"
 6	"testing"
 7
 8	"github.com/stretchr/testify/assert"
 9	"github.com/stretchr/testify/require"
10)
11
12func TestSourcegraphTool_Info(t *testing.T) {
13	tool := NewSourcegraphTool()
14	info := tool.Info()
15
16	assert.Equal(t, SourcegraphToolName, info.Name)
17	assert.NotEmpty(t, info.Description)
18	assert.Contains(t, info.Parameters, "query")
19	assert.Contains(t, info.Parameters, "count")
20	assert.Contains(t, info.Parameters, "timeout")
21	assert.Contains(t, info.Required, "query")
22}
23
24func TestSourcegraphTool_Run(t *testing.T) {
25	t.Run("handles missing query parameter", func(t *testing.T) {
26		tool := NewSourcegraphTool()
27		params := SourcegraphParams{
28			Query: "",
29		}
30
31		paramsJSON, err := json.Marshal(params)
32		require.NoError(t, err)
33
34		call := ToolCall{
35			Name:  SourcegraphToolName,
36			Input: string(paramsJSON),
37		}
38
39		response, err := tool.Run(context.Background(), call)
40		require.NoError(t, err)
41		assert.Contains(t, response.Content, "Query parameter is required")
42	})
43
44	t.Run("handles invalid parameters", func(t *testing.T) {
45		tool := NewSourcegraphTool()
46		call := ToolCall{
47			Name:  SourcegraphToolName,
48			Input: "invalid json",
49		}
50
51		response, err := tool.Run(context.Background(), call)
52		require.NoError(t, err)
53		assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
54	})
55
56	t.Run("normalizes count parameter", func(t *testing.T) {
57		// Test cases for count normalization
58		testCases := []struct {
59			name          string
60			inputCount    int
61			expectedCount int
62		}{
63			{"negative count", -5, 10},    // Should use default (10)
64			{"zero count", 0, 10},         // Should use default (10)
65			{"valid count", 50, 50},       // Should keep as is
66			{"excessive count", 150, 100}, // Should cap at 100
67		}
68
69		for _, tc := range testCases {
70			t.Run(tc.name, func(t *testing.T) {
71				// Verify count normalization logic directly
72				assert.NotPanics(t, func() {
73					// Apply the same normalization logic as in the tool
74					normalizedCount := tc.inputCount
75					if normalizedCount <= 0 {
76						normalizedCount = 10
77					} else if normalizedCount > 100 {
78						normalizedCount = 100
79					}
80
81					assert.Equal(t, tc.expectedCount, normalizedCount)
82				})
83			})
84		}
85	})
86}