1package tools
2
3import (
4 "bytes"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "io"
10 "net/http"
11 "strings"
12 "time"
13
14 "charm.land/fantasy"
15)
16
// SourcegraphParams is the argument object for the sourcegraph tool. The json
// tags name the fields in a tool call; the description tags carry per-field
// help text (presumably surfaced in the schema fantasy generates for the
// tool — confirm against fantasy).
type SourcegraphParams struct {
	// Query is the Sourcegraph search query. The handler rejects an empty query.
	Query string `json:"query" description:"The Sourcegraph search query"`
	// Count is the number of results requested. Non-positive values fall back
	// to 10; values above 20 are clamped to 20.
	Count int `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
	// ContextWindow is how many lines of file content to show around each
	// matched line. Non-positive values fall back to 10.
	ContextWindow int `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
	// Timeout is the request timeout in seconds. Values above 120 are clamped
	// to 120.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
23
// SourcegraphResponseMetadata is a JSON-serializable summary of a Sourcegraph
// search.
//
// NOTE(review): this type is not referenced in this part of the file;
// presumably it is attached to tool responses elsewhere — confirm its producer
// and the exact meaning of each field there.
type SourcegraphResponseMetadata struct {
	// NumberOfMatches presumably holds the number of matches found.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated presumably reports whether results were cut short.
	Truncated bool `json:"truncated"`
}
28
// SourcegraphToolName is the name the Sourcegraph search tool is registered
// under when passed to fantasy.NewParallelAgentTool.
const SourcegraphToolName = "sourcegraph"

// sourcegraphDescription holds the contents of sourcegraph.md, embedded at
// build time and passed as the tool's description to
// fantasy.NewParallelAgentTool.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
33
34func NewSourcegraphTool(client *http.Client) fantasy.AgentTool {
35 if client == nil {
36 transport := http.DefaultTransport.(*http.Transport).Clone()
37 transport.MaxIdleConns = 100
38 transport.MaxIdleConnsPerHost = 10
39 transport.IdleConnTimeout = 90 * time.Second
40
41 client = &http.Client{
42 Timeout: 30 * time.Second,
43 Transport: transport,
44 }
45 }
46 return fantasy.NewParallelAgentTool(
47 SourcegraphToolName,
48 string(sourcegraphDescription),
49 func(ctx context.Context, params SourcegraphParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
50 if params.Query == "" {
51 return fantasy.NewTextErrorResponse("Query parameter is required"), nil
52 }
53
54 if params.Count <= 0 {
55 params.Count = 10
56 } else if params.Count > 20 {
57 params.Count = 20 // Limit to 20 results
58 }
59
60 if params.ContextWindow <= 0 {
61 params.ContextWindow = 10 // Default context window
62 }
63
64 // Handle timeout with context
65 requestCtx := ctx
66 if params.Timeout > 0 {
67 maxTimeout := 120 // 2 minutes
68 if params.Timeout > maxTimeout {
69 params.Timeout = maxTimeout
70 }
71 var cancel context.CancelFunc
72 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
73 defer cancel()
74 }
75
76 type graphqlRequest struct {
77 Query string `json:"query"`
78 Variables struct {
79 Query string `json:"query"`
80 } `json:"variables"`
81 }
82
83 request := graphqlRequest{
84 Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
85 }
86 request.Variables.Query = params.Query
87
88 graphqlQueryBytes, err := json.Marshal(request)
89 if err != nil {
90 return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
91 }
92 graphqlQuery := string(graphqlQueryBytes)
93
94 req, err := http.NewRequestWithContext(
95 requestCtx,
96 "POST",
97 "https://sourcegraph.com/.api/graphql",
98 bytes.NewBuffer([]byte(graphqlQuery)),
99 )
100 if err != nil {
101 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
102 }
103
104 req.Header.Set("Content-Type", "application/json")
105 req.Header.Set("User-Agent", "crush/1.0")
106
107 resp, err := client.Do(req)
108 if err != nil {
109 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
110 }
111 defer resp.Body.Close()
112
113 if resp.StatusCode != http.StatusOK {
114 body, _ := io.ReadAll(resp.Body)
115 if len(body) > 0 {
116 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
117 }
118
119 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
120 }
121 body, err := io.ReadAll(resp.Body)
122 if err != nil {
123 return fantasy.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
124 }
125
126 var result map[string]any
127 if err = json.Unmarshal(body, &result); err != nil {
128 return fantasy.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
129 }
130
131 formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
132 if err != nil {
133 return fantasy.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
134 }
135
136 return fantasy.NewTextResponse(formattedResults), nil
137 })
138}
139
// formatSourcegraphResults renders a JSON-decoded Sourcegraph GraphQL search
// response as Markdown.
//
// If result contains a non-empty "errors" array, only the error messages are
// rendered. Otherwise the function reads data.search.results and writes a
// summary line with the match and result counts. It adds a hint when the
// server reports that its limit was hit, then renders at most maxRendered
// FileMatch entries. Each line match is shown in a fenced block, with up to
// contextWindow lines of file content on either side when the response
// included the file content.
//
// An error is returned only when the data, search, or results object is
// missing from a response that carries no errors.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	// maxRendered caps how many file matches are rendered. It equals the
	// documented maximum of SourcegraphParams.Count.
	const maxRendered = 20

	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			errMap, ok := apiErr.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := errMap["message"].(string); ok {
				fmt.Fprintf(&buffer, "- %s\n", message)
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	// JSON numbers decode as float64 when the target is map[string]any.
	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))
	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}
	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	rendered := 0
	for _, res := range results {
		if rendered >= maxRendered {
			break
		}

		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}
		// Only FileMatch entries carry repository/file/lineMatches; any other
		// result type is skipped.
		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		// Number by rendered position so skipped entries leave no gaps.
		rendered++
		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", rendered, repoName, filePath)
		if fileURL != "" {
			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
		}

		// Split the file once, not once per line match. Without content,
		// lines stays empty and only the matched line is printed.
		var lines []string
		if fileContent != "" {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}

			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)

			// Sourcegraph's LineMatch.lineNumber is 0-based: use it directly
			// as an index into lines and display it as idx+1.
			idx := int(lineNumber)

			buffer.WriteString("```\n")
			for j := max(0, idx-contextWindow); j < idx && j < len(lines); j++ {
				fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
			}
			fmt.Fprintf(&buffer, "%d| %s\n", idx+1, preview)
			for j := max(0, idx+1); j <= idx+contextWindow && j < len(lines); j++ {
				fmt.Fprintf(&buffer, "%d| %s\n", j+1, lines[j])
			}
			buffer.WriteString("```\n\n")
		}
	}

	return buffer.String(), nil
}