sourcegraph.go

  1package tools
  2
  3import (
  4	"bytes"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/fantasy/ai"
 15)
 16
// SourcegraphParams is the JSON argument set accepted by the Sourcegraph
// search tool created by NewSourcegraphTool.
//
// Query is the search string and must be non-empty; the tool handler answers
// an empty query with a text error response. Count is normalised by the
// handler to 10 when it is zero or negative and to 20 when it is larger than
// 20. ContextWindow is the number of context lines rendered around each
// match; zero or negative values are replaced by 10. Timeout, when positive,
// is a number of seconds used to bound the request context and is capped at
// 120.
type SourcegraphParams struct {
	Query         string `json:"query" description:"The Sourcegraph search query"`
	Count         int    `json:"count,omitempty" description:"Optional number of results to return (default: 10, max: 20)"`
	ContextWindow int    `json:"context_window,omitempty" description:"The context around the match to return (default: 10 lines)"`
	Timeout       int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
 23
// SourcegraphResponseMetadata carries summary information about a search
// response: how many matches there were and whether the result set was
// truncated.
//
// NOTE(review): nothing in the visible part of this file populates or returns
// this type; presumably it is consumed elsewhere in the package — confirm.
type SourcegraphResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 28
// SourcegraphToolName is the name under which NewSourcegraphTool registers
// the Sourcegraph search tool.
const SourcegraphToolName = "sourcegraph"
 30
// sourcegraphDescription holds the contents of sourcegraph.md, embedded at
// build time. NewSourcegraphTool passes it (as a string) as the tool
// description.
//
//go:embed sourcegraph.md
var sourcegraphDescription []byte
 33
 34func NewSourcegraphTool() ai.AgentTool {
 35	client := &http.Client{
 36		Timeout: 30 * time.Second,
 37		Transport: &http.Transport{
 38			MaxIdleConns:        100,
 39			MaxIdleConnsPerHost: 10,
 40			IdleConnTimeout:     90 * time.Second,
 41		},
 42	}
 43	return ai.NewAgentTool(
 44		SourcegraphToolName,
 45		string(sourcegraphDescription),
 46		func(ctx context.Context, params SourcegraphParams, call ai.ToolCall) (ai.ToolResponse, error) {
 47			if params.Query == "" {
 48				return ai.NewTextErrorResponse("Query parameter is required"), nil
 49			}
 50
 51			if params.Count <= 0 {
 52				params.Count = 10
 53			} else if params.Count > 20 {
 54				params.Count = 20 // Limit to 20 results
 55			}
 56
 57			if params.ContextWindow <= 0 {
 58				params.ContextWindow = 10 // Default context window
 59			}
 60
 61			// Handle timeout with context
 62			requestCtx := ctx
 63			if params.Timeout > 0 {
 64				maxTimeout := 120 // 2 minutes
 65				if params.Timeout > maxTimeout {
 66					params.Timeout = maxTimeout
 67				}
 68				var cancel context.CancelFunc
 69				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 70				defer cancel()
 71			}
 72
 73			type graphqlRequest struct {
 74				Query     string `json:"query"`
 75				Variables struct {
 76					Query string `json:"query"`
 77				} `json:"variables"`
 78			}
 79
 80			request := graphqlRequest{
 81				Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
 82			}
 83			request.Variables.Query = params.Query
 84
 85			graphqlQueryBytes, err := json.Marshal(request)
 86			if err != nil {
 87				return ai.ToolResponse{}, fmt.Errorf("failed to marshal GraphQL request: %w", err)
 88			}
 89			graphqlQuery := string(graphqlQueryBytes)
 90
 91			req, err := http.NewRequestWithContext(
 92				requestCtx,
 93				"POST",
 94				"https://sourcegraph.com/.api/graphql",
 95				bytes.NewBuffer([]byte(graphqlQuery)),
 96			)
 97			if err != nil {
 98				return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 99			}
100
101			req.Header.Set("Content-Type", "application/json")
102			req.Header.Set("User-Agent", "crush/1.0")
103
104			resp, err := client.Do(req)
105			if err != nil {
106				return ai.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
107			}
108			defer resp.Body.Close()
109
110			if resp.StatusCode != http.StatusOK {
111				body, _ := io.ReadAll(resp.Body)
112				if len(body) > 0 {
113					return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
114				}
115
116				return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
117			}
118			body, err := io.ReadAll(resp.Body)
119			if err != nil {
120				return ai.ToolResponse{}, fmt.Errorf("failed to read response body: %w", err)
121			}
122
123			var result map[string]any
124			if err = json.Unmarshal(body, &result); err != nil {
125				return ai.ToolResponse{}, fmt.Errorf("failed to unmarshal response: %w", err)
126			}
127
128			formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
129			if err != nil {
130				return ai.NewTextErrorResponse("Failed to format results: " + err.Error()), nil
131			}
132
133			return ai.NewTextResponse(formattedResults), nil
134		})
135}
136
// formatSourcegraphResults renders a decoded Sourcegraph GraphQL search
// response as Markdown.
//
// result is the top-level JSON object of the response. If it contains a
// non-empty "errors" array, only the error messages are rendered and the
// returned error is nil. Otherwise data.search.results must be present; a
// missing or mistyped level produces an "invalid response format" error.
//
// At most 10 results are rendered, and only entries whose __typename is
// "FileMatch" and that carry both a repository and a file object. Each line
// match becomes one fenced block; when the file content is available, up to
// contextWindow lines before and after the match are included.
func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
	// maxResults is a fixed cap on rendered results, independent of what the
	// caller asked for.
	const maxResults = 10

	var buffer strings.Builder

	if apiErrors, ok := result["errors"].([]any); ok && len(apiErrors) > 0 {
		buffer.WriteString("## Sourcegraph API Error\n\n")
		for _, apiErr := range apiErrors {
			errMap, ok := apiErr.(map[string]any)
			if !ok {
				continue
			}
			if message, ok := errMap["message"].(string); ok {
				fmt.Fprintf(&buffer, "- %s\n", message)
			}
		}
		return buffer.String(), nil
	}

	data, ok := result["data"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing data field")
	}

	search, ok := data["search"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing search field")
	}

	searchResults, ok := search["results"].(map[string]any)
	if !ok {
		return "", fmt.Errorf("invalid response format: missing results field")
	}

	// JSON numbers decode to float64; absent fields fall back to zero values.
	matchCount, _ := searchResults["matchCount"].(float64)
	resultCount, _ := searchResults["resultCount"].(float64)
	limitHit, _ := searchResults["limitHit"].(bool)

	buffer.WriteString("# Sourcegraph Search Results\n\n")
	fmt.Fprintf(&buffer, "Found %d matches across %d results\n", int(matchCount), int(resultCount))

	if limitHit {
		buffer.WriteString("(Result limit reached, try a more specific query)\n")
	}

	buffer.WriteString("\n")

	results, ok := searchResults["results"].([]any)
	if !ok || len(results) == 0 {
		buffer.WriteString("No results found. Try a different query.\n")
		return buffer.String(), nil
	}

	if len(results) > maxResults {
		results = results[:maxResults]
	}

	for i, res := range results {
		fileMatch, ok := res.(map[string]any)
		if !ok {
			continue
		}

		if typeName, _ := fileMatch["__typename"].(string); typeName != "FileMatch" {
			continue
		}

		repo, _ := fileMatch["repository"].(map[string]any)
		file, _ := fileMatch["file"].(map[string]any)
		if repo == nil || file == nil {
			continue
		}
		lineMatches, _ := fileMatch["lineMatches"].([]any)

		repoName, _ := repo["name"].(string)
		filePath, _ := file["path"].(string)
		fileURL, _ := file["url"].(string)
		fileContent, _ := file["content"].(string)

		// The heading number is the position in the (capped) result list, so
		// skipped entries leave gaps.
		fmt.Fprintf(&buffer, "## Result %d: %s/%s\n\n", i+1, repoName, filePath)

		if fileURL != "" {
			fmt.Fprintf(&buffer, "URL: %s\n\n", fileURL)
		}

		// Split the file once per result rather than once per line match.
		var lines []string
		if fileContent != "" {
			lines = strings.Split(fileContent, "\n")
		}

		for _, lm := range lineMatches {
			lineMatch, ok := lm.(map[string]any)
			if !ok {
				continue
			}

			lineNumber, _ := lineMatch["lineNumber"].(float64)
			preview, _ := lineMatch["preview"].(string)

			writeSourcegraphLineMatch(&buffer, lines, int(lineNumber), preview, contextWindow)
		}
	}

	return buffer.String(), nil
}

// writeSourcegraphLineMatch writes one fenced block for a single line match.
//
// lineIndex is the line number as reported by Sourcegraph, which is 0-based;
// displayed line numbers are therefore lineIndex+1. A negative lineIndex
// (Sourcegraph uses -1 for "match without lines") prints only the preview,
// with no line label and no context. lines is the file content split on
// newlines, or nil/empty when the content is unavailable, in which case only
// the labelled preview is printed. The matched line is marked by two spaces
// after the bar instead of one.
func writeSourcegraphLineMatch(buffer *strings.Builder, lines []string, lineIndex int, preview string, contextWindow int) {
	buffer.WriteString("```\n")

	switch {
	case lineIndex < 0:
		buffer.WriteString(preview)
		buffer.WriteString("\n")

	case len(lines) == 0:
		fmt.Fprintf(buffer, "%d| %s\n", lineIndex+1, preview)

	default:
		// Leading context: the contextWindow lines right before the match,
		// clipped to the start of the file and to the available lines.
		first := max(0, lineIndex-contextWindow)
		for j := first; j < lineIndex && j < len(lines); j++ {
			fmt.Fprintf(buffer, "%d| %s\n", j+1, lines[j])
		}

		fmt.Fprintf(buffer, "%d|  %s\n", lineIndex+1, preview)

		// Trailing context: slice first, then trim, so no index arithmetic
		// can run past the end of the file.
		trailing := lines[min(lineIndex+1, len(lines)):]
		if contextWindow < len(trailing) {
			trailing = trailing[:max(contextWindow, 0)]
		}
		for k, text := range trailing {
			fmt.Fprintf(buffer, "%d| %s\n", lineIndex+2+k, text)
		}
	}

	buffer.WriteString("```\n\n")
}