sourcegraph.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/ai"
 14)
 15
 16type SourcegraphParams struct {
 17	Query         string `json:"query" description:"The Sourcegraph search query"`
 18	Count         int    `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
 19	ContextWindow int    `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
 20	Timeout       int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
 21}
 22
 23type SourcegraphResponseMetadata struct {
 24	NumberOfMatches int  `json:"number_of_matches"`
 25	Truncated       bool `json:"truncated"`
 26}
 27
 28const (
 29	SourcegraphToolName = "sourcegraph"
 30)
 31
 32func NewSourcegraphTool() ai.AgentTool {
 33	client := &http.Client{
 34		Timeout: 30 * time.Second,
 35		Transport: &http.Transport{
 36			MaxIdleConns:        100,
 37			MaxIdleConnsPerHost: 10,
 38			IdleConnTimeout:     90 * time.Second,
 39		},
 40	}
 41	return ai.NewTypedToolFunc(
 42		SourcegraphToolName,
 43		`Search code across public repositories using Sourcegraph's GraphQL API.
 44
 45WHEN TO USE THIS TOOL:
 46- Use when you need to find code examples or implementations across public repositories
 47- Helpful for researching how others have solved similar problems
 48- Useful for discovering patterns and best practices in open source code
 49
 50HOW TO USE:
 51- Provide a search query using Sourcegraph's query syntax
 52- Optionally specify the number of results to return (default: 10)
 53- Optionally set a timeout for the request
 54
 55QUERY SYNTAX:
 56- Basic search: "fmt.Println" searches for exact matches
 57- File filters: "file:.go fmt.Println" limits to Go files
 58- Repository filters: "repo:^github\.com/golang/go$ fmt.Println" limits to specific repos
 59- Language filters: "lang:go fmt.Println" limits to Go code
 60- Boolean operators: "fmt.Println AND log.Fatal" for combined terms
 61- Regular expressions: "fmt\.(Print|Printf|Println)" for pattern matching
 62- Quoted strings: "\"exact phrase\"" for exact phrase matching
 63- Exclude filters: "-file:test" or "-repo:forks" to exclude matches
 64
 65ADVANCED FILTERS:
 66- Repository filters:
 67  * "repo:name" - Match repositories with name containing "name"
 68  * "repo:^github\.com/org/repo$" - Exact repository match
 69  * "repo:org/repo@branch" - Search specific branch
 70  * "repo:org/repo rev:branch" - Alternative branch syntax
 71  * "-repo:name" - Exclude repositories
 72  * "fork:yes" or "fork:only" - Include or only show forks
 73  * "archived:yes" or "archived:only" - Include or only show archived repos
 74  * "visibility:public" or "visibility:private" - Filter by visibility
 75
 76- File filters:
 77  * "file:\.js$" - Files with .js extension
 78  * "file:internal/" - Files in internal directory
 79  * "-file:test" - Exclude test files
 80  * "file:has.content(Copyright)" - Files containing "Copyright"
 81  * "file:has.contributor([email protected])" - Files with specific contributor
 82
 83- Content filters:
 84  * "content:\"exact string\"" - Search for exact string
 85  * "-content:\"unwanted\"" - Exclude files with unwanted content
 86  * "case:yes" - Case-sensitive search
 87
 88- Type filters:
 89  * "type:symbol" - Search for symbols (functions, classes, etc.)
 90  * "type:file" - Search file content only
 91  * "type:path" - Search filenames only
 92  * "type:diff" - Search code changes
 93  * "type:commit" - Search commit messages
 94
 95- Commit/diff search:
 96  * "after:\"1 month ago\"" - Commits after date
 97  * "before:\"2023-01-01\"" - Commits before date
 98  * "author:name" - Commits by author
 99  * "message:\"fix bug\"" - Commits with message
100
101- Result selection:
102  * "select:repo" - Show only repository names
103  * "select:file" - Show only file paths
104  * "select:content" - Show only matching content
105  * "select:symbol" - Show only matching symbols
106
107- Result control:
108  * "count:100" - Return up to 100 results
109  * "count:all" - Return all results
110  * "timeout:30s" - Set search timeout
111
112EXAMPLES:
113- "file:.go context.WithTimeout" - Find Go code using context.WithTimeout
114- "lang:typescript useState type:symbol" - Find TypeScript React useState hooks
115- "repo:^github\.com/kubernetes/kubernetes$ pod list type:file" - Find Kubernetes files related to pod listing
116- "repo:sourcegraph/sourcegraph$ after:\"3 months ago\" type:diff database" - Recent changes to database code
117- "file:Dockerfile (alpine OR ubuntu) -content:alpine:latest" - Dockerfiles with specific base images
118- "repo:has.path(\.py) file:requirements.txt tensorflow" - Python projects using TensorFlow
119
120BOOLEAN OPERATORS:
121- "term1 AND term2" - Results containing both terms
122- "term1 OR term2" - Results containing either term
123- "term1 NOT term2" - Results with term1 but not term2
124- "term1 and (term2 or term3)" - Grouping with parentheses
125
126LIMITATIONS:
127- Only searches public repositories
128- Rate limits may apply
129- Complex queries may take longer to execute
130- Maximum of 20 results per query
131
132TIPS:
133- Use specific file extensions to narrow results
134- Add repo: filters for more targeted searches
135- Use type:symbol to find function/method definitions
136- Use type:file to find relevant files`,
137		func(ctx context.Context, params SourcegraphParams, call ai.ToolCall) (ai.ToolResponse, error) {
138			if params.Query == "" {
139				return ai.NewTextErrorResponse("Query parameter is required"), nil
140			}
141
142			if params.Count <= 0 {
143				params.Count = 10
144			} else if params.Count > 20 {
145				params.Count = 20 // Limit to 20 results
146			}
147
148			if params.ContextWindow <= 0 {
149				params.ContextWindow = 10 // Default context window
150			}
151
152			// Handle timeout with context
153			requestCtx := ctx
154			if params.Timeout > 0 {
155				maxTimeout := 120 // 2 minutes
156				if params.Timeout > maxTimeout {
157					params.Timeout = maxTimeout
158				}
159				var cancel context.CancelFunc
160				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
161				defer cancel()
162			}
163
164			type graphqlRequest struct {
165				Query     string `json:"query"`
166				Variables struct {
167					Query string `json:"query"`
168				} `json:"variables"`
169			}
170
171			request := graphqlRequest{
172				Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
173			}
174			request.Variables.Query = params.Query
175
176			graphqlQueryBytes, err := json.Marshal(request)
177			if err != nil {
178				return ai.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
179			}
180			graphqlQuery := string(graphqlQueryBytes)
181
182			req, err := http.NewRequestWithContext(
183				requestCtx,
184				"POST",
185				"https://sourcegraph.com/.api/graphql",
186				bytes.NewBuffer([]byte(graphqlQuery)),
187			)
188			if err != nil {
189				return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
190			}
191
192			req.Header.Set("Content-Type", "application/json")
193			req.Header.Set("User-Agent", "crush/1.0")
194
195			resp, err := client.Do(req)
196			if err != nil {
197				return ai.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
198			}
199			defer resp.Body.Close()
200
201			if resp.StatusCode != http.StatusOK {
202				body, _ := io.ReadAll(resp.Body)
203				if len(body) > 0 {
204					return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
205				}
206
207				return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
208			}
209			body, err := io.ReadAll(resp.Body)
210			if err != nil {
211				return ai.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
212			}
213
214			var result map[string]any
215			if err = json.Unmarshal(body, &result); err != nil {
216				return ai.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
217			}
218
219			formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
220			if err != nil {
221				return ai.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
222			}
223
224			return ai.NewTextResponse(formattedResults), nil
225		})
226}
227
// formatSourcegraphResults renders a decoded Sourcegraph GraphQL search
// response as Markdown.
//
// result is the response body decoded into a generic map. contextWindow is
// the number of file lines printed before and after each matched line when
// the response includes the file content.
//
// If the response carries a non-empty "errors" array, the returned text lists
// the error messages and the error is nil. An error is returned only when
// "data", "data.search" or "data.search.results" is missing or not an object.
// At most the first 10 entries of the result list are considered, and entries
// that are not FileMatch objects with both a repository and a file are
// skipped.
//
// Line numbers: Sourcegraph's GraphQL schema documents LineMatch.lineNumber as
// 0-based, so it is used directly as an index into the file's lines and
// printed as lineNumber+1.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			errMap, ok := apiErr.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := errMap["message"].(string); ok {
				fmt.Fprintf(&buffer, "- %s\n", message)
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	// Missing or mistyped summary fields fall back to their zero values.
	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))

	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}

	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	const maxResults = 10
	if len(results) > maxResults {
		results = results[:maxResults]
	}

	// shown counts the results actually rendered, so headings stay
	// consecutive even when some entries are skipped.
	shown := 0
	for _, res := range results {
		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}

		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		shown++
		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", shown, repoName, filePath)

		if fileURL != "" {
			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
		}

		// Split the file once per result instead of once per line match.
		var lines []string
		if fileContent != "" {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}

			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)
			idx := int(lineNumber) // 0-based index of the matched line

			buffer.WriteString("```\n")

			if fileContent == "" {
				// No file content to draw context from: print the match only.
				fmt.Fprintf(&buffer, "%d| %s\n", idx+1, preview)
				buffer.WriteString("```\n\n")
				continue
			}

			// Up to contextWindow lines before the match.
			for j := max(0, idx-contextWindow); j < idx && j < len(lines); j++ {
				fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
			}

			// The matched line is written with an extra space after the bar,
			// which sets it apart from the context lines.
			fmt.Fprintf(&buffer, "%d|  %s\n", idx+1, preview)

			// Up to contextWindow lines after the match.
			for j := max(0, idx+1); j <= idx+contextWindow && j < len(lines); j++ {
				fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
			}

			buffer.WriteString("```\n\n")
		}
	}

	return buffer.String(), nil
}