1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"strings"
 12	"time"
 13)
 14
// SourcegraphParams is the JSON input accepted by the sourcegraph tool's Run
// method.
//
//   - Query: the Sourcegraph search query (required; Run rejects an empty one).
//   - Count: number of results; Run replaces non-positive values with 10 and
//     caps the value at 20.
//   - ContextWindow: lines of file context around each match; Run replaces
//     non-positive values with 10.
//   - Timeout: request timeout in seconds; Run caps it at 120.
type SourcegraphParams struct {
	Query         string `json:"query"`
	Count         int    `json:"count,omitempty"`
	ContextWindow int    `json:"context_window,omitempty"`
	Timeout       int    `json:"timeout,omitempty"`
}
 21
// SourcegraphResponseMetadata carries summary information about a Sourcegraph
// search: a match count and, judging by the field name, a flag for whether the
// results were truncated.
//
// NOTE(review): this type is not referenced by the code in this file; confirm
// where (or whether) it is populated.
type SourcegraphResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 26
// sourcegraphTool is the BaseTool implementation returned by
// NewSourcegraphTool; its Run method queries the sourcegraph.com GraphQL
// search endpoint.
type sourcegraphTool struct {
	// client sends the GraphQL search requests issued by Run.
	client *http.Client
}
 30
// SourcegraphToolName is the tool name reported by Name and Info.
const SourcegraphToolName = "sourcegraph"

// sourcegraphDescription is the tool description returned by Info, embedded
// at build time from sourcegraph.md.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
 35
 36func NewSourcegraphTool() BaseTool {
 37	return &sourcegraphTool{
 38		client: &http.Client{
 39			Timeout: 30 * time.Second,
 40			Transport: &http.Transport{
 41				MaxIdleConns:        100,
 42				MaxIdleConnsPerHost: 10,
 43				IdleConnTimeout:     90 * time.Second,
 44			},
 45		},
 46	}
 47}
 48
 49func (t *sourcegraphTool) Name() string {
 50	return SourcegraphToolName
 51}
 52
 53func (t *sourcegraphTool) Info() ToolInfo {
 54	return ToolInfo{
 55		Name:        SourcegraphToolName,
 56		Description: string(sourcegraphDescription),
 57		Parameters: map[string]any{
 58			"query": map[string]any{
 59				"type":        "string",
 60				"description": "The Sourcegraph search query",
 61			},
 62			"count": map[string]any{
 63				"type":        "number",
 64				"description": "Optional number of results to return (default: 10, max: 20)",
 65			},
 66			"context_window": map[string]any{
 67				"type":        "number",
 68				"description": "The context around the match to return (default: 10 lines)",
 69			},
 70			"timeout": map[string]any{
 71				"type":        "number",
 72				"description": "Optional timeout in seconds (max 120)",
 73			},
 74		},
 75		Required: []string{"query"},
 76	}
 77}
 78
 79func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 80	var params SourcegraphParams
 81	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 82		return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
 83	}
 84
 85	if params.Query == "" {
 86		return NewTextErrorResponse("Query parameter is required"), nil
 87	}
 88
 89	if params.Count <= 0 {
 90		params.Count = 10
 91	} else if params.Count > 20 {
 92		params.Count = 20 // Limit to 20 results
 93	}
 94
 95	if params.ContextWindow <= 0 {
 96		params.ContextWindow = 10 // Default context window
 97	}
 98
 99	// Handle timeout with context
100	requestCtx := ctx
101	if params.Timeout > 0 {
102		maxTimeout := 120 // 2 minutes
103		if params.Timeout > maxTimeout {
104			params.Timeout = maxTimeout
105		}
106		var cancel context.CancelFunc
107		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
108		defer cancel()
109	}
110
111	type graphqlRequest struct {
112		Query     string `json:"query"`
113		Variables struct {
114			Query string `json:"query"`
115		} `json:"variables"`
116	}
117
118	request := graphqlRequest{
119		Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
120	}
121	request.Variables.Query = params.Query
122
123	graphqlQueryBytes, err := json.Marshal(request)
124	if err != nil {
125		return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
126	}
127	graphqlQuery := string(graphqlQueryBytes)
128
129	req, err := http.NewRequestWithContext(
130		requestCtx,
131		"POST",
132		"https://sourcegraph.com/.api/graphql",
133		bytes.NewBuffer([]byte(graphqlQuery)),
134	)
135	if err != nil {
136		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
137	}
138
139	req.Header.Set("Content-Type", "application/json")
140	req.Header.Set("User-Agent", "crush/1.0")
141
142	resp, err := t.client.Do(req)
143	if err != nil {
144		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
145	}
146	defer resp.Body.Close()
147
148	if resp.StatusCode != http.StatusOK {
149		body, _ := io.ReadAll(resp.Body)
150		if len(body) > 0 {
151			return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
152		}
153
154		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
155	}
156	body, err := io.ReadAll(resp.Body)
157	if err != nil {
158		return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
159	}
160
161	var result map[string]any
162	if err = json.Unmarshal(body, &result); err != nil {
163		return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
164	}
165
166	formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
167	if err != nil {
168		return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
169	}
170
171	return NewTextResponse(formattedResults), nil
172}
173
// formatSourcegraphResults renders a JSON-decoded Sourcegraph GraphQL search
// response as Markdown.
//
// If result carries a non-empty "errors" array, only the error messages are
// rendered. Otherwise it writes a summary from data.search.results, followed
// by up to 10 FileMatch results numbered consecutively. Other result types are
// skipped and do not count toward that limit. Each line match is shown in a
// fenced block. When the response includes the file content, up to
// contextWindow lines on each side of the match are included; negative values
// are treated as 0.
//
// Sourcegraph documents LineMatch.lineNumber as 0-based. It is used as an index
// into the file content and displayed 1-based.
//
// An error is returned only when the data/search/results envelope is missing
// or has an unexpected shape.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	// maxResults caps how many file matches are rendered, however many the
	// server returned.
	const maxResults = 10

	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			if errMap, ok := apiErr.(map[string]any); ok {
				if message, ok := errMap["message"].(string); ok {
					fmt.Fprintf(&buffer, "- %s\n", message)
				}
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	// JSON numbers decode to float64 inside map[string]any.
	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))

	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}

	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	rendered := 0
	for _, res := range results {
		if rendered >= maxResults {
			break
		}

		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}
		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		rendered++
		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", rendered, repoName, filePath)

		if fileURL != "" {
			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
		}

		// Split the content once per file rather than once per line match.
		var lines []string
		if fileContent != "" {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}
			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)
			writeSourcegraphSnippet(&buffer, lines, int(lineNumber), preview, contextWindow)
		}
	}

	return buffer.String(), nil
}

// writeSourcegraphSnippet writes one line match as a fenced code block with
// 1-based line labels.
//
// matchIndex is the 0-based line index of the match. The match line is
// written from preview, marked by two spaces after the "|". When lines is
// non-empty, up to contextWindow lines before and after the match are taken
// from lines; out-of-range indices are skipped and negative contextWindow
// values are treated as 0. When lines is empty, only the match line is written.
func writeSourcegraphSnippet(buffer *strings.Builder, lines []string, matchIndex int, preview string, contextWindow int) {
	buffer.WriteString("```\n")
	defer buffer.WriteString("```\n\n")

	if len(lines) == 0 {
		fmt.Fprintf(buffer, "%d| %s\n", matchIndex+1, preview)
		return
	}

	contextWindow = max(contextWindow, 0)

	for j := max(0, matchIndex-contextWindow); j < min(matchIndex, len(lines)); j++ {
		fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
	}

	fmt.Fprintf(buffer, "%d|  %s\n", matchIndex+1, preview)

	for j := max(matchIndex+1, 0); j < min(matchIndex+1+contextWindow, len(lines)); j++ {
		fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
	}
}