sourcegraph.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"html/template"
 10	"io"
 11	"net/http"
 12	"strings"
 13	"time"
 14
 15	"charm.land/fantasy"
 16)
 17
// SourcegraphParams are the arguments accepted by the sourcegraph tool.
//
// The json tags name each argument and the description tags carry its help
// text (presumably turned into the tool's argument schema by the fantasy
// framework — not visible in this file). Zero or negative values select the
// defaults applied by the tool handler: Count 10 and ContextWindow 10, with no
// caller-chosen Timeout. Count is clamped to at most 20 and Timeout to at most
// 120 seconds.
type SourcegraphParams struct {
	Query         string `json:"query" description:"The Sourcegraph search query"`
	Count         int    `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
	ContextWindow int    `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
	Timeout       int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
 24
// SourcegraphResponseMetadata summarizes a sourcegraph search: how many
// matches were found and whether the result set was truncated.
//
// NOTE(review): nothing in this file populates this type; presumably it is
// attached to tool responses elsewhere in the package — confirm before
// relying on it.
type SourcegraphResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 29
// SourcegraphToolName is the name NewSourcegraphTool registers the tool under.
const SourcegraphToolName = "sourcegraph"
 31
// sourcegraphDescriptionTmpl holds the raw tool-description template,
// embedded from sourcegraph.md.tpl at build time.
//
//go:embed sourcegraph.md.tpl
var sourcegraphDescriptionTmpl []byte

// sourcegraphDescriptionTpl is the parsed description template. template.Must
// panics during package initialization if the embedded template does not parse.
//
// NOTE(review): this is html/template, which HTML-escapes interpolated values;
// text/template may suit a Markdown description better — check the parameter
// type of renderTemplate (defined elsewhere) before switching.
var sourcegraphDescriptionTpl = template.Must(
	template.New("sourcegraphDescription").
		Parse(string(sourcegraphDescriptionTmpl)),
)
 39
// sourcegraphDescriptionData is the data rendered into sourcegraphDescriptionTpl.
type sourcegraphDescriptionData struct {
	// MaxResults is the largest result count, presumably interpolated into the
	// description text by the template (set to 20 by sourcegraphDescription).
	MaxResults int
}
 43
 44func sourcegraphDescription() string {
 45	return renderTemplate(sourcegraphDescriptionTpl, sourcegraphDescriptionData{
 46		MaxResults: 20,
 47	})
 48}
 49
 50func NewSourcegraphTool(client *http.Client) fantasy.AgentTool {
 51	if client == nil {
 52		transport := http.DefaultTransport.(*http.Transport).Clone()
 53		transport.MaxIdleConns = 100
 54		transport.MaxIdleConnsPerHost = 10
 55		transport.IdleConnTimeout = 90 * time.Second
 56
 57		client = &http.Client{
 58			Timeout:   30 * time.Second,
 59			Transport: transport,
 60		}
 61	}
 62	return fantasy.NewParallelAgentTool(
 63		SourcegraphToolName,
 64		sourcegraphDescription(),
 65		func(ctx context.Context, params SourcegraphParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 66			if params.Query == "" {
 67				return fantasy.NewTextErrorResponse("Query parameter is required"), nil
 68			}
 69
 70			if params.Count <= 0 {
 71				params.Count = 10
 72			} else if params.Count > 20 {
 73				params.Count = 20 // Limit to 20 results
 74			}
 75
 76			if params.ContextWindow <= 0 {
 77				params.ContextWindow = 10 // Default context window
 78			}
 79
 80			// Handle timeout with context
 81			requestCtx := ctx
 82			if params.Timeout > 0 {
 83				maxTimeout := 120 // 2 minutes
 84				if params.Timeout > maxTimeout {
 85					params.Timeout = maxTimeout
 86				}
 87				var cancel context.CancelFunc
 88				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 89				defer cancel()
 90			}
 91
 92			type graphqlRequest struct {
 93				Query     string `json:"query"`
 94				Variables struct {
 95					Query string `json:"query"`
 96				} `json:"variables"`
 97			}
 98
 99			request := graphqlRequest{
100				Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
101			}
102			request.Variables.Query = params.Query
103
104			graphqlQueryBytes, err := json.Marshal(request)
105			if err != nil {
106				return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
107			}
108			graphqlQuery := string(graphqlQueryBytes)
109
110			req, err := http.NewRequestWithContext(
111				requestCtx,
112				"POST",
113				"https://sourcegraph.com/.api/graphql",
114				bytes.NewBuffer([]byte(graphqlQuery)),
115			)
116			if err != nil {
117				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
118			}
119
120			req.Header.Set("Content-Type", "application/json")
121			req.Header.Set("User-Agent", "crush/1.0")
122
123			resp, err := client.Do(req)
124			if err != nil {
125				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
126			}
127			defer resp.Body.Close()
128
129			if resp.StatusCode != http.StatusOK {
130				body, _ := io.ReadAll(resp.Body)
131				if len(body) > 0 {
132					return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
133				}
134
135				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
136			}
137			body, err := io.ReadAll(resp.Body)
138			if err != nil {
139				return fantasy.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
140			}
141
142			var result map[string]any
143			if err = json.Unmarshal(body, &result); err != nil {
144				return fantasy.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
145			}
146
147			formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
148			if err != nil {
149				return fantasy.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
150			}
151
152			return fantasy.NewTextResponse(formattedResults), nil
153		})
154}
155
156func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
157	var buffer strings.Builder
158
159	if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
160		buffer.WriteString("## Sourcegraph API Error\n\n")
161		for _, err := range errors {
162			if errMap, ok := err.(map[string]any); ok {
163				if message, ok := errMap["message"].(string); ok {
164					fmt.Fprintf(&buffer, "- %s\n", message)
165				}
166			}
167		}
168		return buffer.String(), nil
169	}
170
171	data, ok := result["data"].(map[string]any)
172	if !ok {
173		return "", fmt.Errorf("invalid response format: missing data field")
174	}
175
176	search, ok := data["search"].(map[string]any)
177	if !ok {
178		return "", fmt.Errorf("invalid response format: missing search field")
179	}
180
181	searchResults, ok := search["results"].(map[string]any)
182	if !ok {
183		return "", fmt.Errorf("invalid response format: missing results field")
184	}
185
186	matchCount, _ := searchResults["matchCount"].(float64)
187	resultCount, _ := searchResults["resultCount"].(float64)
188	limitHit, _ := searchResults["limitHit"].(bool)
189
190	buffer.WriteString("# Sourcegraph Search Results\n\n")
191	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))
192
193	if limitHit {
194		buffer.WriteString("(Result limit reached, try a more specific query)\n")
195	}
196
197	buffer.WriteString("\n")
198
199	results, ok := searchResults["results"].([]any)
200	if !ok || len(results) == 0 {
201		buffer.WriteString("No results found. Try a different query.\n")
202		return buffer.String(), nil
203	}
204
205	maxResults := 10
206	if len(results) > maxResults {
207		results = results[:maxResults]
208	}
209
210	for i, res := range results {
211		fileMatch, ok := res.(map[string]any)
212		if !ok {
213			continue
214		}
215
216		typeName, _ := fileMatch["__typename"].(string)
217		if typeName != "FileMatch" {
218			continue
219		}
220
221		repo, _ := fileMatch["repository"].(map[string]any)
222		file, _ := fileMatch["file"].(map[string]any)
223		lineMatches, _ := fileMatch["lineMatches"].([]any)
224
225		if repo == nil || file == nil {
226			continue
227		}
228
229		repoName, _ := repo["name"].(string)
230		filePath, _ := file["path"].(string)
231		fileURL, _ := file["url"].(string)
232		fileContent, _ := file["content"].(string)
233
234		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", i+1, repoName, filePath)
235
236		if fileURL != "" {
237			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
238		}
239
240		if len(lineMatches) > 0 {
241			for _, lm := range lineMatches {
242				lineMatch, ok := lm.(map[string]any)
243				if !ok {
244					continue
245				}
246
247				lineNumber, _ := lineMatch["lineNumber"].(float64)
248				preview, _ := lineMatch["preview"].(string)
249
250				if fileContent != "" {
251					lines := strings.Split(fileContent, "\n")
252
253					buffer.WriteString("```\n")
254
255					startLine := max(1, int(lineNumber)-contextWindow)
256
257					for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
258						if j >= 0 {
259							fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
260						}
261					}
262
263					fmt.Fprintf(&buffer, "%d|  %s\n", int(lineNumber), preview)
264
265					endLine := int(lineNumber) + contextWindow
266
267					for j := int(lineNumber); j < endLine && j < len(lines); j++ {
268						if j < len(lines) {
269							fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
270						}
271					}
272
273					buffer.WriteString("```\n\n")
274				} else {
275					buffer.WriteString("```\n")
276					fmt.Fprintf(&buffer, "%d| %s\n", int(lineNumber), preview)
277					buffer.WriteString("```\n\n")
278				}
279			}
280		}
281	}
282
283	return buffer.String(), nil
284}