sourcegraph.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"strings"
 12	"time"
 13
 14	"charm.land/fantasy"
 15)
 16
// SourcegraphParams is the input accepted by the sourcegraph tool handler
// created by NewSourcegraphTool.
type SourcegraphParams struct {
	// Query is the Sourcegraph search query; the handler rejects an empty value.
	Query         string `json:"query" description:"The Sourcegraph search query"`
	// Count is normalized by the handler to 10 when <= 0 and capped at 20.
	Count         int    `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
	// ContextWindow is the number of context lines around each match; the
	// handler defaults it to 10 when <= 0.
	ContextWindow int    `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
	// Timeout is an optional per-call timeout in seconds; the handler caps it
	// at 120 and applies no extra deadline when it is <= 0.
	Timeout       int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
 23
// SourcegraphResponseMetadata carries a match count and a truncation flag for
// a Sourcegraph search response. It is not referenced by the code visible in
// this file; presumably it is consumed elsewhere — verify against callers.
type SourcegraphResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 28
// SourcegraphToolName is the name under which the Sourcegraph search tool is
// registered by NewSourcegraphTool.
const SourcegraphToolName = "sourcegraph"
 30
// sourcegraphDescription holds the contents of sourcegraph.md, embedded at
// build time. NewSourcegraphTool passes it (converted to a string) as the
// tool's description.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
 33
 34func NewSourcegraphTool(client *http.Client) fantasy.AgentTool {
 35	if client == nil {
 36		transport := http.DefaultTransport.(*http.Transport).Clone()
 37		transport.MaxIdleConns = 100
 38		transport.MaxIdleConnsPerHost = 10
 39		transport.IdleConnTimeout = 90 * time.Second
 40
 41		client = &http.Client{
 42			Timeout:   30 * time.Second,
 43			Transport: transport,
 44		}
 45	}
 46	return fantasy.NewParallelAgentTool(
 47		SourcegraphToolName,
 48		string(sourcegraphDescription),
 49		func(ctx context.Context, params SourcegraphParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 50			if params.Query == "" {
 51				return fantasy.NewTextErrorResponse("Query parameter is required"), nil
 52			}
 53
 54			if params.Count <= 0 {
 55				params.Count = 10
 56			} else if params.Count > 20 {
 57				params.Count = 20 // Limit to 20 results
 58			}
 59
 60			if params.ContextWindow <= 0 {
 61				params.ContextWindow = 10 // Default context window
 62			}
 63
 64			// Handle timeout with context
 65			requestCtx := ctx
 66			if params.Timeout > 0 {
 67				maxTimeout := 120 // 2 minutes
 68				if params.Timeout > maxTimeout {
 69					params.Timeout = maxTimeout
 70				}
 71				var cancel context.CancelFunc
 72				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 73				defer cancel()
 74			}
 75
 76			type graphqlRequest struct {
 77				Query     string `json:"query"`
 78				Variables struct {
 79					Query string `json:"query"`
 80				} `json:"variables"`
 81			}
 82
 83			request := graphqlRequest{
 84				Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
 85			}
 86			request.Variables.Query = params.Query
 87
 88			graphqlQueryBytes, err := json.Marshal(request)
 89			if err != nil {
 90				return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
 91			}
 92			graphqlQuery := string(graphqlQueryBytes)
 93
 94			req, err := http.NewRequestWithContext(
 95				requestCtx,
 96				"POST",
 97				"https://sourcegraph.com/.api/graphql",
 98				bytes.NewBuffer([]byte(graphqlQuery)),
 99			)
100			if err != nil {
101				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
102			}
103
104			req.Header.Set("Content-Type", "application/json")
105			req.Header.Set("User-Agent", "crush/1.0")
106
107			resp, err := client.Do(req)
108			if err != nil {
109				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
110			}
111			defer resp.Body.Close()
112
113			if resp.StatusCode != http.StatusOK {
114				body, _ := io.ReadAll(resp.Body)
115				if len(body) > 0 {
116					return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
117				}
118
119				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
120			}
121			body, err := io.ReadAll(resp.Body)
122			if err != nil {
123				return fantasy.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
124			}
125
126			var result map[string]any
127			if err = json.Unmarshal(body, &result); err != nil {
128				return fantasy.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
129			}
130
131			formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
132			if err != nil {
133				return fantasy.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
134			}
135
136			return fantasy.NewTextResponse(formattedResults), nil
137		})
138}
139
// formatSourcegraphResults renders a decoded Sourcegraph GraphQL search
// response as Markdown.
//
// result is the JSON response decoded into a generic map. If it has a
// non-empty "errors" array, only the error messages are rendered and a nil
// error is returned. Otherwise the function expects data.search.results and
// returns an error naming the first of those fields that is missing or has
// the wrong type.
//
// At most the first 10 entries of the result list are considered; entries that
// are not FileMatch objects, or that lack repository or file information, are
// skipped (result numbering follows the position in the list, so skipped
// entries leave gaps). contextWindow is the number of file lines printed
// before and after each matching line when the file content is available.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			errMap, ok := apiErr.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := errMap["message"].(string); ok {
				fmt.Fprintf(&buffer, "- %s\n", message)
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	// JSON numbers decode to float64; absent or mistyped fields fall back to
	// the zero value on purpose.
	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))

	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}

	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	const maxResults = 10
	if len(results) > maxResults {
		results = results[:maxResults]
	}

	for i, res := range results {
		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}

		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", i+1, repoName, filePath)

		if fileURL != "" {
			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
		}

		// Split the file once per result instead of once per line match. A nil
		// slice tells the helper that no content is available.
		var lines []string
		if fileContent != "" && len(lineMatches) > 0 {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}

			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)

			writeSourcegraphLineMatch(&buffer, lines, int(lineNumber), preview, contextWindow)
		}
	}

	return buffer.String(), nil
}

// writeSourcegraphLineMatch appends one fenced code block for a single line
// match. When lines is empty only the preview is printed; otherwise up to
// contextWindow lines of file content are printed before and after it.
//
// NOTE(review): lineNumber is treated as 1-based when selecting context lines.
// Sourcegraph's GraphQL schema may document LineMatch.lineNumber as 0-based —
// confirm against the API before changing the arithmetic.
func writeSourcegraphLineMatch(buffer *strings.Builder, lines []string, lineNumber int, preview string, contextWindow int) {
	buffer.WriteString("```\n")

	if len(lines) == 0 {
		fmt.Fprintf(buffer, "%d| %s\n", lineNumber, preview)
		buffer.WriteString("```\n\n")
		return
	}

	// Leading context: 0-based indices [first, lineNumber-1), clamped to the file.
	first := max(1, lineNumber-contextWindow) - 1
	for j := first; j < lineNumber-1 && j < len(lines); j++ {
		fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
	}

	// The matching line itself is marked by the extra space after the bar.
	fmt.Fprintf(buffer, "%d|  %s\n", lineNumber, preview)

	// Trailing context. The start index is clamped at 0 so that a malformed
	// (negative) line number in the response cannot index before the slice.
	last := min(lineNumber+contextWindow, len(lines))
	for j := max(0, lineNumber); j < last; j++ {
		fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
	}

	buffer.WriteString("```\n\n")
}