1package tools
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "io"
10 "net/http"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/proto"
15)
16
17type (
18 SourcegraphParams = proto.SourcegraphParams
19 SourcegraphResponseMetadata = proto.SourcegraphResponseMetadata
20)
21
// sourcegraphTool is the Sourcegraph code-search tool. It keeps a single
// HTTP client that every Run call uses for its GraphQL request.
type sourcegraphTool struct {
	client *http.Client // shared across Run calls
}
25
26const SourcegraphToolName = proto.SourcegraphToolName
27
// sourcegraphDescription holds the contents of sourcegraph.md, embedded into
// the binary at build time. Info returns it (as a string) as the tool
// description.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
30
31func NewSourcegraphTool() BaseTool {
32 return &sourcegraphTool{
33 client: &http.Client{
34 Timeout: 30 * time.Second,
35 Transport: &http.Transport{
36 MaxIdleConns: 100,
37 MaxIdleConnsPerHost: 10,
38 IdleConnTimeout: 90 * time.Second,
39 },
40 },
41 }
42}
43
44func (t *sourcegraphTool) Name() string {
45 return SourcegraphToolName
46}
47
48func (t *sourcegraphTool) Info() ToolInfo {
49 return ToolInfo{
50 Name: SourcegraphToolName,
51 Description: string(sourcegraphDescription),
52 Parameters: map[string]any{
53 "query": map[string]any{
54 "type": "string",
55 "description": "The Sourcegraph search query",
56 },
57 "count": map[string]any{
58 "type": "number",
59 "description": "Optional number of results to return (default: 10, max: 20)",
60 },
61 "context_window": map[string]any{
62 "type": "number",
63 "description": "The context around the match to return (default: 10 lines)",
64 },
65 "timeout": map[string]any{
66 "type": "number",
67 "description": "Optional timeout in seconds (max 120)",
68 },
69 },
70 Required: []string{"query"},
71 }
72}
73
74func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
75 var params SourcegraphParams
76 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
77 return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
78 }
79
80 if params.Query == "" {
81 return NewTextErrorResponse("Query parameter is required"), nil
82 }
83
84 if params.Count <= 0 {
85 params.Count = 10
86 } else if params.Count > 20 {
87 params.Count = 20 // Limit to 20 results
88 }
89
90 if params.ContextWindow <= 0 {
91 params.ContextWindow = 10 // Default context window
92 }
93
94 // Handle timeout with context
95 requestCtx := ctx
96 if params.Timeout > 0 {
97 maxTimeout := 120 // 2 minutes
98 if params.Timeout > maxTimeout {
99 params.Timeout = maxTimeout
100 }
101 var cancel context.CancelFunc
102 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
103 defer cancel()
104 }
105
106 type graphqlRequest struct {
107 Query string `json:"query"`
108 Variables struct {
109 Query string `json:"query"`
110 } `json:"variables"`
111 }
112
113 request := graphqlRequest{
114 Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
115 }
116 request.Variables.Query = params.Query
117
118 graphqlQueryBytes, err := json.Marshal(request)
119 if err != nil {
120 return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
121 }
122 graphqlQuery := string(graphqlQueryBytes)
123
124 req, err := http.NewRequestWithContext(
125 requestCtx,
126 "POST",
127 "https://sourcegraph.com/.api/graphql",
128 bytes.NewBuffer([]byte(graphqlQuery)),
129 )
130 if err != nil {
131 return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
132 }
133
134 req.Header.Set("Content-Type", "application/json")
135 req.Header.Set("User-Agent", "crush/1.0")
136
137 resp, err := t.client.Do(req)
138 if err != nil {
139 return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
140 }
141 defer resp.Body.Close()
142
143 if resp.StatusCode != http.StatusOK {
144 body, _ := io.ReadAll(resp.Body)
145 if len(body) > 0 {
146 return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
147 }
148
149 return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
150 }
151 body, err := io.ReadAll(resp.Body)
152 if err != nil {
153 return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
154 }
155
156 var result map[string]any
157 if err = json.Unmarshal(body, &result); err != nil {
158 return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
159 }
160
161 formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
162 if err != nil {
163 return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
164 }
165
166 return NewTextResponse(formattedResults), nil
167}
168
// formatSourcegraphResults renders a decoded Sourcegraph GraphQL search
// response as Markdown.
//
// result is the JSON-decoded response body, decoded into map[string]any, so
// numbers are float64. How it is rendered:
//   - A non-empty "errors" array: only the error messages are rendered, and no
//     error is returned.
//   - Otherwise data.search.results must be present. If any of those objects is
//     missing, an error is returned.
//
// At most 10 FileMatch results are rendered. Other result types are skipped
// and do not count toward that limit. Each line match prints the matched line,
// plus up to contextWindow lines on each side when the response includes the
// file content. A negative contextWindow is treated as 0.
//
// Sourcegraph documents LineMatch.lineNumber as 0-based. It is used directly
// as an index into the file's lines and shown to the reader 1-based.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	const maxResults = 10

	contextWindow = max(contextWindow, 0)

	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			errMap, ok := apiErr.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := errMap["message"].(string); ok {
				fmt.Fprintf(&buffer, "- %s\n", message)
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))
	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}
	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	shown := 0
	for _, res := range results {
		if shown == maxResults {
			break
		}

		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}
		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		// Number only the results actually rendered, so headings stay contiguous.
		shown++
		writeSourcegraphFileMatch(&buffer, shown, repo, file, lineMatches, contextWindow)
	}

	return buffer.String(), nil
}

// writeSourcegraphFileMatch appends one "## Result N" section for a FileMatch
// to buffer. The section holds a heading, the file URL (when present), and one
// fenced snippet per line match.
func writeSourcegraphFileMatch(buffer *strings.Builder, position int, repo, file map[string]any, lineMatches []any, contextWindow int) {
	repoName, _ := repo["name"].(string)
	filePath, _ := file["path"].(string)
	fileURL, _ := file["url"].(string)
	fileContent, _ := file["content"].(string)

	fmt.Fprintf(buffer, "## Result %d: %s/%s\n\n", position, repoName, filePath)
	if fileURL != "" {
		fmt.Fprintf(buffer, "URL: %s\n\n", fileURL)
	}

	// Split the file once per result, not once per line match. A single
	// trailing newline is dropped so it does not surface as an extra empty line.
	var lines []string
	if fileContent != "" {
		lines = strings.Split(strings.TrimSuffix(fileContent, "\n"), "\n")
	}

	for _, lm := range lineMatches {
		lineMatch, ok := lm.(map[string]any)
		if !ok {
			continue
		}

		lineNumber, _ := lineMatch["lineNumber"].(float64)
		preview, _ := lineMatch["preview"].(string)

		// idx is the 0-based index of the matched line; idx+1 is what we print.
		idx := int(lineNumber)
		withContext := lines != nil && idx >= 0

		buffer.WriteString("```\n")
		if withContext {
			for j := max(0, idx-contextWindow); j < idx && j < len(lines); j++ {
				fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
			}
		}
		fmt.Fprintf(buffer, "%d| %s\n", idx+1, preview)
		if withContext {
			for j := idx + 1; j <= idx+contextWindow && j < len(lines); j++ {
				fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
			}
		}
		buffer.WriteString("```\n\n")
	}
}