1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/fantasy/ai"
 15)
 16
// SourcegraphParams is the input accepted by the Sourcegraph search tool.
// The json and description tags describe each field to the tool caller.
type SourcegraphParams struct {
	// Query is the Sourcegraph search query; the tool handler rejects an
	// empty value.
	Query         string `json:"query" description:"The Sourcegraph search query"`
	// Count is the requested number of results. The handler replaces
	// non-positive values with 10 and caps larger values at 20.
	Count         int    `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
	// ContextWindow is the number of surrounding lines to show around a
	// match. The handler replaces non-positive values with 10.
	ContextWindow int    `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
	// Timeout is the per-call time limit in seconds; values above 120 are
	// capped by the handler.
	Timeout       int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
 23
// SourcegraphResponseMetadata is a JSON-serializable summary of a search
// response.
//
// NOTE(review): nothing in the visible part of this file populates or reads
// this type; presumably it is consumed elsewhere in the package — confirm.
type SourcegraphResponseMetadata struct {
	// NumberOfMatches is serialized as "number_of_matches".
	NumberOfMatches int  `json:"number_of_matches"`
	// Truncated is serialized as "truncated".
	Truncated       bool `json:"truncated"`
}
 28
// SourcegraphToolName is the name under which the Sourcegraph search tool is
// registered by NewSourcegraphTool.
const SourcegraphToolName = "sourcegraph"
 30
// sourcegraphDescription holds the contents of sourcegraph.md, embedded at
// build time; NewSourcegraphTool passes it on as the tool description.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
 33
 34func NewSourcegraphTool(client *http.Client) ai.AgentTool {
 35	if client == nil {
 36		client = &http.Client{
 37			Timeout: 30 * time.Second,
 38			Transport: &http.Transport{
 39				MaxIdleConns:        100,
 40				MaxIdleConnsPerHost: 10,
 41				IdleConnTimeout:     90 * time.Second,
 42			},
 43		}
 44	}
 45	return ai.NewAgentTool(
 46		SourcegraphToolName,
 47		string(sourcegraphDescription),
 48		func(ctx context.Context, params SourcegraphParams, call ai.ToolCall) (ai.ToolResponse, error) {
 49			if params.Query == "" {
 50				return ai.NewTextErrorResponse("Query parameter is required"), nil
 51			}
 52
 53			if params.Count <= 0 {
 54				params.Count = 10
 55			} else if params.Count > 20 {
 56				params.Count = 20 // Limit to 20 results
 57			}
 58
 59			if params.ContextWindow <= 0 {
 60				params.ContextWindow = 10 // Default context window
 61			}
 62
 63			// Handle timeout with context
 64			requestCtx := ctx
 65			if params.Timeout > 0 {
 66				maxTimeout := 120 // 2 minutes
 67				if params.Timeout > maxTimeout {
 68					params.Timeout = maxTimeout
 69				}
 70				var cancel context.CancelFunc
 71				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 72				defer cancel()
 73			}
 74
 75			type graphqlRequest struct {
 76				Query     string `json:"query"`
 77				Variables struct {
 78					Query string `json:"query"`
 79				} `json:"variables"`
 80			}
 81
 82			request := graphqlRequest{
 83				Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
 84			}
 85			request.Variables.Query = params.Query
 86
 87			graphqlQueryBytes, err := json.Marshal(request)
 88			if err != nil {
 89				return ai.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
 90			}
 91			graphqlQuery := string(graphqlQueryBytes)
 92
 93			req, err := http.NewRequestWithContext(
 94				requestCtx,
 95				"POST",
 96				"https://sourcegraph.com/.api/graphql",
 97				bytes.NewBuffer([]byte(graphqlQuery)),
 98			)
 99			if err != nil {
100				return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
101			}
102
103			req.Header.Set("Content-Type", "application/json")
104			req.Header.Set("User-Agent", "crush/1.0")
105
106			resp, err := client.Do(req)
107			if err != nil {
108				return ai.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
109			}
110			defer resp.Body.Close()
111
112			if resp.StatusCode != http.StatusOK {
113				body, _ := io.ReadAll(resp.Body)
114				if len(body) > 0 {
115					return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
116				}
117
118				return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
119			}
120			body, err := io.ReadAll(resp.Body)
121			if err != nil {
122				return ai.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
123			}
124
125			var result map[string]any
126			if err = json.Unmarshal(body, &result); err != nil {
127				return ai.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
128			}
129
130			formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
131			if err != nil {
132				return ai.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
133			}
134
135			return ai.NewTextResponse(formattedResults), nil
136		})
137}
138
// formatSourcegraphResults renders a decoded Sourcegraph GraphQL search
// response as Markdown.
//
// result is the top-level JSON object of the response. A non-empty "errors"
// array is rendered as an error section and returned with a nil error. A
// response that lacks data.search.results yields an error. contextWindow is
// the number of file lines shown before and after each matching line when the
// file content is part of the response.
//
// At most the first 10 entries of the result list are considered and only
// FileMatch entries are rendered. Result numbers follow the position in the
// response, so skipped entries leave gaps in the numbering.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			errMap, ok := apiErr.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := errMap["message"].(string); ok {
				fmt.Fprintf(&buffer, "- %s\n", message)
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	// encoding/json decodes JSON numbers into float64 when the target is any;
	// absent or mistyped fields fall back to their zero value.
	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))

	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}

	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	const maxResults = 10
	if len(results) > maxResults {
		results = results[:maxResults]
	}

	for i, res := range results {
		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}

		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", i+1, repoName, filePath)

		if fileURL != "" {
			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
		}

		if len(lineMatches) == 0 {
			continue
		}

		// Split the file once and share the lines across all of its matches.
		var lines []string
		if fileContent != "" {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}

			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)

			writeSourcegraphLineMatch(&buffer, lines, int(lineNumber), preview, contextWindow)
		}
	}

	return buffer.String(), nil
}

// writeSourcegraphLineMatch appends one fenced code block for a single line
// match to buffer.
//
// index is the line number as reported by Sourcegraph, whose GraphQL schema
// documents it as 0-based with -1 standing for a whole-file match; displayed
// line numbers are therefore index+1. lines is the file content split into
// lines, or empty when the content is unavailable, in which case only the
// preview is written. A negative index has no line to anchor context on, so
// only a non-empty preview is written, without a line number.
func writeSourcegraphLineMatch(buffer *strings.Builder, lines []string, index int, preview string, contextWindow int) {
	if index < 0 && preview == "" {
		return
	}

	buffer.WriteString("```\n")

	switch {
	case index < 0:
		buffer.WriteString(preview)
		buffer.WriteString("\n")
	case len(lines) == 0:
		fmt.Fprintf(buffer, "%d| %s\n", index+1, preview)
	default:
		// Leading context, clamped to the start and end of the file.
		for j := max(0, index-contextWindow); j < index && j < len(lines); j++ {
			fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
		}

		// The extra space after the bar sets the matching line apart.
		fmt.Fprintf(buffer, "%d|  %s\n", index+1, preview)

		// Trailing context, clamped to the end of the file.
		for j := index + 1; j <= index+contextWindow && j < len(lines); j++ {
			fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
		}
	}

	buffer.WriteString("```\n\n")
}