sourcegraph.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"html/template"
 10	"io"
 11	"net/http"
 12	"strings"
 13	"time"
 14
 15	"charm.land/fantasy"
 16)
 17
// SourcegraphParams is the JSON-decoded input of the sourcegraph tool.
type SourcegraphParams struct {
	// Query is the Sourcegraph search query. The tool handler rejects an
	// empty query with an error response.
	Query         string `json:"query" description:"The Sourcegraph search query"`
	// Count is the requested number of results. The tool handler substitutes
	// 10 for non-positive values and caps larger values at 20.
	Count         int    `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
	// ContextWindow is the number of lines of context shown around each match.
	// The tool handler substitutes 10 for non-positive values.
	ContextWindow int    `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
	// Timeout is an optional per-call timeout in seconds. The tool handler
	// caps it at 120.
	Timeout       int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
 24
// SourcegraphResponseMetadata carries summary information about a Sourcegraph
// search response in JSON form.
//
// NOTE(review): nothing in the visible part of this file populates or returns
// this type; the field meanings below are taken from their names — confirm
// against the callers that use it.
type SourcegraphResponseMetadata struct {
	// NumberOfMatches is serialized as "number_of_matches".
	NumberOfMatches int  `json:"number_of_matches"`
	// Truncated is serialized as "truncated".
	Truncated       bool `json:"truncated"`
}
 29
// SourcegraphToolName is the identifier NewSourcegraphTool passes to
// fantasy.NewParallelAgentTool when it constructs the tool.
const SourcegraphToolName = "sourcegraph"
 31
// sourcegraphDescriptionTmpl holds the raw bytes of sourcegraph.md.tpl,
// embedded into the binary at build time.
//
//go:embed sourcegraph.md.tpl
var sourcegraphDescriptionTmpl []byte

// sourcegraphDescriptionTpl is the parsed form of the embedded description
// template. template.Must panics during package initialization if the
// template fails to parse.
//
// NOTE(review): this file imports html/template, which HTML-escapes
// interpolated values. If the rendered description is meant to be plain
// Markdown, text/template may be the intended package — confirm against
// renderTemplate's parameter type before changing it.
var sourcegraphDescriptionTpl = template.Must(
	template.New("sourcegraphDescription").
		Parse(string(sourcegraphDescriptionTmpl)),
)
 39
// sourcegraphDescriptionData is the value handed to the description template
// when it is rendered.
type sourcegraphDescriptionData struct {
	// MaxResults is exposed to the template; sourcegraphDescription sets it
	// to 20.
	MaxResults int
}
 43
 44func sourcegraphDescription() string {
 45	return renderTemplate(sourcegraphDescriptionTpl, sourcegraphDescriptionData{
 46		MaxResults: 20,
 47	})
 48}
 49
 50func NewSourcegraphTool(client *http.Client) fantasy.AgentTool {
 51	if client == nil {
 52		transport := http.DefaultTransport.(*http.Transport).Clone()
 53		transport.MaxIdleConns = 100
 54		transport.MaxIdleConnsPerHost = 10
 55		transport.IdleConnTimeout = 90 * time.Second
 56
 57		client = &http.Client{
 58			Timeout:   30 * time.Second,
 59			Transport: transport,
 60		}
 61	}
 62	return fantasy.NewParallelAgentTool(
 63		SourcegraphToolName,
 64		sourcegraphDescription(),
 65		func(ctx context.Context, params SourcegraphParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 66			if params.Query == "" {
 67				return fantasy.NewTextErrorResponse("Query parameter is required"), nil
 68			}
 69
 70			if params.Count <= 0 {
 71				params.Count = 10
 72			} else if params.Count > 20 {
 73				params.Count = 20 // Limit to 20 results
 74			}
 75
 76			if params.ContextWindow <= 0 {
 77				params.ContextWindow = 10 // Default context window
 78			}
 79
 80			// Handle timeout with context
 81			requestCtx := ctx
 82			if params.Timeout > 0 {
 83				maxTimeout := 120 // 2 minutes
 84				if params.Timeout > maxTimeout {
 85					params.Timeout = maxTimeout
 86				}
 87				var cancel context.CancelFunc
 88				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 89				defer cancel()
 90			}
 91
 92			type graphqlRequest struct {
 93				Query     string `json:"query"`
 94				Variables struct {
 95					Query string `json:"query"`
 96				} `json:"variables"`
 97			}
 98
 99			request := graphqlRequest{
100				Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
101			}
102			request.Variables.Query = params.Query
103
104			graphqlQueryBytes, err := json.Marshal(request)
105			if err != nil {
106				return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
107			}
108			graphqlQuery := string(graphqlQueryBytes)
109
110			req, err := http.NewRequestWithContext(
111				requestCtx,
112				"POST",
113				"https://sourcegraph.com/.api/graphql",
114				bytes.NewBuffer([]byte(graphqlQuery)),
115			)
116			if err != nil {
117				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
118			}
119
120			req.Header.Set("Content-Type", "application/json")
121			req.Header.Set("User-Agent", "crush/1.0")
122
123			resp, err := client.Do(req)
124			if err != nil {
125				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
126			}
127			defer resp.Body.Close()
128
129			if resp.StatusCode != http.StatusOK {
130				body, _ := io.ReadAll(resp.Body)
131				if len(body) > 0 {
132					return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
133				}
134
135				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
136			}
137			body, err := io.ReadAll(resp.Body)
138			if err != nil {
139				return fantasy.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
140			}
141
142			var result map[string]any
143			if err = json.Unmarshal(body, &result); err != nil {
144				return fantasy.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
145			}
146
147			formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
148			if err != nil {
149				return fantasy.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
150			}
151
152			return fantasy.NewTextResponse(formattedResults), nil
153		},
154	)
155}
156
157func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
158	var buffer strings.Builder
159
160	if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
161		buffer.WriteString("## Sourcegraph API Error\n\n")
162		for _, err := range errors {
163			if errMap, ok := err.(map[string]any); ok {
164				if message, ok := errMap["message"].(string); ok {
165					fmt.Fprintf(&buffer, "- %s\n", message)
166				}
167			}
168		}
169		return buffer.String(), nil
170	}
171
172	data, ok := result["data"].(map[string]any)
173	if !ok {
174		return "", fmt.Errorf("invalid response format: missing data field")
175	}
176
177	search, ok := data["search"].(map[string]any)
178	if !ok {
179		return "", fmt.Errorf("invalid response format: missing search field")
180	}
181
182	searchResults, ok := search["results"].(map[string]any)
183	if !ok {
184		return "", fmt.Errorf("invalid response format: missing results field")
185	}
186
187	matchCount, _ := searchResults["matchCount"].(float64)
188	resultCount, _ := searchResults["resultCount"].(float64)
189	limitHit, _ := searchResults["limitHit"].(bool)
190
191	buffer.WriteString("# Sourcegraph Search Results\n\n")
192	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))
193
194	if limitHit {
195		buffer.WriteString("(Result limit reached, try a more specific query)\n")
196	}
197
198	buffer.WriteString("\n")
199
200	results, ok := searchResults["results"].([]any)
201	if !ok || len(results) == 0 {
202		buffer.WriteString("No results found. Try a different query.\n")
203		return buffer.String(), nil
204	}
205
206	maxResults := 10
207	if len(results) > maxResults {
208		results = results[:maxResults]
209	}
210
211	for i, res := range results {
212		fileMatch, ok := res.(map[string]any)
213		if !ok {
214			continue
215		}
216
217		typeName, _ := fileMatch["__typename"].(string)
218		if typeName != "FileMatch" {
219			continue
220		}
221
222		repo, _ := fileMatch["repository"].(map[string]any)
223		file, _ := fileMatch["file"].(map[string]any)
224		lineMatches, _ := fileMatch["lineMatches"].([]any)
225
226		if repo == nil || file == nil {
227			continue
228		}
229
230		repoName, _ := repo["name"].(string)
231		filePath, _ := file["path"].(string)
232		fileURL, _ := file["url"].(string)
233		fileContent, _ := file["content"].(string)
234
235		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", i+1, repoName, filePath)
236
237		if fileURL != "" {
238			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
239		}
240
241		if len(lineMatches) > 0 {
242			for _, lm := range lineMatches {
243				lineMatch, ok := lm.(map[string]any)
244				if !ok {
245					continue
246				}
247
248				lineNumber, _ := lineMatch["lineNumber"].(float64)
249				preview, _ := lineMatch["preview"].(string)
250
251				if fileContent != "" {
252					lines := strings.Split(fileContent, "\n")
253
254					buffer.WriteString("```\n")
255
256					startLine := max(1, int(lineNumber)-contextWindow)
257
258					for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
259						if j >= 0 {
260							fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
261						}
262					}
263
264					fmt.Fprintf(&buffer, "%d|  %s\n", int(lineNumber), preview)
265
266					endLine := int(lineNumber) + contextWindow
267
268					for j := int(lineNumber); j < endLine && j < len(lines); j++ {
269						if j < len(lines) {
270							fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
271						}
272					}
273
274					buffer.WriteString("```\n\n")
275				} else {
276					buffer.WriteString("```\n")
277					fmt.Fprintf(&buffer, "%d| %s\n", int(lineNumber), preview)
278					buffer.WriteString("```\n\n")
279				}
280			}
281		}
282	}
283
284	return buffer.String(), nil
285}