1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/proto"
 15)
 16
 17type (
 18	SourcegraphParams           = proto.SourcegraphParams
 19	SourcegraphResponseMetadata = proto.SourcegraphResponseMetadata
 20)
 21
 22type sourcegraphTool struct {
 23	client *http.Client
 24}
 25
 26const SourcegraphToolName = proto.SourcegraphToolName
 27
// sourcegraphDescription holds the contents of sourcegraph.md, embedded at
// build time; Info returns it as the tool description.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
 30
 31func NewSourcegraphTool() BaseTool {
 32	return &sourcegraphTool{
 33		client: &http.Client{
 34			Timeout: 30 * time.Second,
 35			Transport: &http.Transport{
 36				MaxIdleConns:        100,
 37				MaxIdleConnsPerHost: 10,
 38				IdleConnTimeout:     90 * time.Second,
 39			},
 40		},
 41	}
 42}
 43
 44func (t *sourcegraphTool) Name() string {
 45	return SourcegraphToolName
 46}
 47
 48func (t *sourcegraphTool) Info() ToolInfo {
 49	return ToolInfo{
 50		Name:        SourcegraphToolName,
 51		Description: string(sourcegraphDescription),
 52		Parameters: map[string]any{
 53			"query": map[string]any{
 54				"type":        "string",
 55				"description": "The Sourcegraph search query",
 56			},
 57			"count": map[string]any{
 58				"type":        "number",
 59				"description": "Optional number of results to return (default: 10, max: 20)",
 60			},
 61			"context_window": map[string]any{
 62				"type":        "number",
 63				"description": "The context around the match to return (default: 10 lines)",
 64			},
 65			"timeout": map[string]any{
 66				"type":        "number",
 67				"description": "Optional timeout in seconds (max 120)",
 68			},
 69		},
 70		Required: []string{"query"},
 71	}
 72}
 73
 74func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 75	var params SourcegraphParams
 76	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 77		return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
 78	}
 79
 80	if params.Query == "" {
 81		return NewTextErrorResponse("Query parameter is required"), nil
 82	}
 83
 84	if params.Count <= 0 {
 85		params.Count = 10
 86	} else if params.Count > 20 {
 87		params.Count = 20 // Limit to 20 results
 88	}
 89
 90	if params.ContextWindow <= 0 {
 91		params.ContextWindow = 10 // Default context window
 92	}
 93
 94	// Handle timeout with context
 95	requestCtx := ctx
 96	if params.Timeout > 0 {
 97		maxTimeout := 120 // 2 minutes
 98		if params.Timeout > maxTimeout {
 99			params.Timeout = maxTimeout
100		}
101		var cancel context.CancelFunc
102		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
103		defer cancel()
104	}
105
106	type graphqlRequest struct {
107		Query     string `json:"query"`
108		Variables struct {
109			Query string `json:"query"`
110		} `json:"variables"`
111	}
112
113	request := graphqlRequest{
114		Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
115	}
116	request.Variables.Query = params.Query
117
118	graphqlQueryBytes, err := json.Marshal(request)
119	if err != nil {
120		return ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
121	}
122	graphqlQuery := string(graphqlQueryBytes)
123
124	req, err := http.NewRequestWithContext(
125		requestCtx,
126		"POST",
127		"https://sourcegraph.com/.api/graphql",
128		bytes.NewBuffer([]byte(graphqlQuery)),
129	)
130	if err != nil {
131		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
132	}
133
134	req.Header.Set("Content-Type", "application/json")
135	req.Header.Set("User-Agent", "crush/1.0")
136
137	resp, err := t.client.Do(req)
138	if err != nil {
139		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
140	}
141	defer resp.Body.Close()
142
143	if resp.StatusCode != http.StatusOK {
144		body, _ := io.ReadAll(resp.Body)
145		if len(body) > 0 {
146			return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
147		}
148
149		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
150	}
151	body, err := io.ReadAll(resp.Body)
152	if err != nil {
153		return ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
154	}
155
156	var result map[string]any
157	if err = json.Unmarshal(body, &result); err != nil {
158		return ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
159	}
160
161	formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
162	if err != nil {
163		return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
164	}
165
166	return NewTextResponse(formattedResults), nil
167}
168
169func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
170	var buffer strings.Builder
171
172	if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
173		buffer.WriteString("## Sourcegraph API Error\n\n")
174		for _, err := range errors {
175			if errMap, ok := err.(map[string]any); ok {
176				if message, ok := errMap["message"].(string); ok {
177					buffer.WriteString(fmt.Sprintf("- %s\n", message))
178				}
179			}
180		}
181		return buffer.String(), nil
182	}
183
184	data, ok := result["data"].(map[string]any)
185	if !ok {
186		return "", fmt.Errorf("invalid response format: missing data field")
187	}
188
189	search, ok := data["search"].(map[string]any)
190	if !ok {
191		return "", fmt.Errorf("invalid response format: missing search field")
192	}
193
194	searchResults, ok := search["results"].(map[string]any)
195	if !ok {
196		return "", fmt.Errorf("invalid response format: missing results field")
197	}
198
199	matchCount, _ := searchResults["matchCount"].(float64)
200	resultCount, _ := searchResults["resultCount"].(float64)
201	limitHit, _ := searchResults["limitHit"].(bool)
202
203	buffer.WriteString("# Sourcegraph Search Results\n\n")
204	buffer.WriteString(fmt.Sprintf("Found %d matches across %d results\n", int(matchCount), int(resultCount)))
205
206	if limitHit {
207		buffer.WriteString("(Result limit reached, try a more specific query)\n")
208	}
209
210	buffer.WriteString("\n")
211
212	results, ok := searchResults["results"].([]any)
213	if !ok || len(results) == 0 {
214		buffer.WriteString("No results found. Try a different query.\n")
215		return buffer.String(), nil
216	}
217
218	maxResults := 10
219	if len(results) > maxResults {
220		results = results[:maxResults]
221	}
222
223	for i, res := range results {
224		fileMatch, ok := res.(map[string]any)
225		if !ok {
226			continue
227		}
228
229		typeName, _ := fileMatch["__typename"].(string)
230		if typeName != "FileMatch" {
231			continue
232		}
233
234		repo, _ := fileMatch["repository"].(map[string]any)
235		file, _ := fileMatch["file"].(map[string]any)
236		lineMatches, _ := fileMatch["lineMatches"].([]any)
237
238		if repo == nil || file == nil {
239			continue
240		}
241
242		repoName, _ := repo["name"].(string)
243		filePath, _ := file["path"].(string)
244		fileURL, _ := file["url"].(string)
245		fileContent, _ := file["content"].(string)
246
247		buffer.WriteString(fmt.Sprintf("## Result %d: %s/%s\n\n", i+1, repoName, filePath))
248
249		if fileURL != "" {
250			buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL))
251		}
252
253		if len(lineMatches) > 0 {
254			for _, lm := range lineMatches {
255				lineMatch, ok := lm.(map[string]any)
256				if !ok {
257					continue
258				}
259
260				lineNumber, _ := lineMatch["lineNumber"].(float64)
261				preview, _ := lineMatch["preview"].(string)
262
263				if fileContent != "" {
264					lines := strings.Split(fileContent, "\n")
265
266					buffer.WriteString("```\n")
267
268					startLine := max(1, int(lineNumber)-contextWindow)
269
270					for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
271						if j >= 0 {
272							buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
273						}
274					}
275
276					buffer.WriteString(fmt.Sprintf("%d|  %s\n", int(lineNumber), preview))
277
278					endLine := int(lineNumber) + contextWindow
279
280					for j := int(lineNumber); j < endLine && j < len(lines); j++ {
281						if j < len(lines) {
282							buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
283						}
284					}
285
286					buffer.WriteString("```\n\n")
287				} else {
288					buffer.WriteString("```\n")
289					buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
290					buffer.WriteString("```\n\n")
291				}
292			}
293		}
294	}
295
296	return buffer.String(), nil
297}