add sourcegraph tool

Kujtim Hoxha created

Change summary

internal/llm/tools/sourcegraph.go      | 401 ++++++++++++++++++++++++++++
internal/llm/tools/sourcegraph_test.go | 115 ++++++++
2 files changed, 516 insertions(+)

Detailed changes

internal/llm/tools/sourcegraph.go 🔗

@@ -0,0 +1,401 @@
+package tools
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+)
+
+const (
+	SourcegraphToolName        = "sourcegraph"
+	sourcegraphToolDescription = `Search code across public repositories using Sourcegraph's GraphQL API.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to find code examples or implementations across public repositories
+- Helpful for researching how others have solved similar problems
+- Useful for discovering patterns and best practices in open source code
+
+HOW TO USE:
+- Provide a search query using Sourcegraph's query syntax
+- Optionally specify the number of results to return (default: 10)
+- Optionally set a timeout for the request
+
+QUERY SYNTAX:
+- Basic search: "fmt.Println" searches for exact matches
+- File filters: "file:.go fmt.Println" limits to Go files
+- Repository filters: "repo:^github\.com/golang/go$ fmt.Println" limits to specific repos
+- Language filters: "lang:go fmt.Println" limits to Go code
+- Boolean operators: "fmt.Println AND log.Fatal" for combined terms
+- Regular expressions: "fmt\.(Print|Printf|Println)" for pattern matching
+- Quoted strings: "\"exact phrase\"" for exact phrase matching
+- Exclude filters: "-file:test" or "-repo:forks" to exclude matches
+
+ADVANCED FILTERS:
+- Repository filters:
+  * "repo:name" - Match repositories with name containing "name"
+  * "repo:^github\.com/org/repo$" - Exact repository match
+  * "repo:org/repo@branch" - Search specific branch
+  * "repo:org/repo rev:branch" - Alternative branch syntax
+  * "-repo:name" - Exclude repositories
+  * "fork:yes" or "fork:only" - Include or only show forks
+  * "archived:yes" or "archived:only" - Include or only show archived repos
+  * "visibility:public" or "visibility:private" - Filter by visibility
+
+- File filters:
+  * "file:\.js$" - Files with .js extension
+  * "file:internal/" - Files in internal directory
+  * "-file:test" - Exclude test files
+  * "file:has.content(Copyright)" - Files containing "Copyright"
+  * "file:has.contributor([email protected])" - Files with specific contributor
+
+- Content filters:
+  * "content:\"exact string\"" - Search for exact string
+  * "-content:\"unwanted\"" - Exclude files with unwanted content
+  * "case:yes" - Case-sensitive search
+
+- Type filters:
+  * "type:symbol" - Search for symbols (functions, classes, etc.)
+  * "type:file" - Search file content only
+  * "type:path" - Search filenames only
+  * "type:diff" - Search code changes
+  * "type:commit" - Search commit messages
+
+- Commit/diff search:
+  * "after:\"1 month ago\"" - Commits after date
+  * "before:\"2023-01-01\"" - Commits before date
+  * "author:name" - Commits by author
+  * "message:\"fix bug\"" - Commits with message
+
+- Result selection:
+  * "select:repo" - Show only repository names
+  * "select:file" - Show only file paths
+  * "select:content" - Show only matching content
+  * "select:symbol" - Show only matching symbols
+
+- Result control:
+  * "count:100" - Return up to 100 results
+  * "count:all" - Return all results
+  * "timeout:30s" - Set search timeout
+
+EXAMPLES:
+- "file:.go context.WithTimeout" - Find Go code using context.WithTimeout
+- "lang:typescript useState type:symbol" - Find TypeScript React useState hooks
+- "repo:^github\.com/kubernetes/kubernetes$ pod list type:file" - Find Kubernetes files related to pod listing
+- "repo:sourcegraph/sourcegraph$ after:\"3 months ago\" type:diff database" - Recent changes to database code
+- "file:Dockerfile (alpine OR ubuntu) -content:alpine:latest" - Dockerfiles with specific base images
+- "repo:has.path(\.py) file:requirements.txt tensorflow" - Python projects using TensorFlow
+
+BOOLEAN OPERATORS:
+- "term1 AND term2" - Results containing both terms
+- "term1 OR term2" - Results containing either term
+- "term1 NOT term2" - Results with term1 but not term2
+- "term1 and (term2 or term3)" - Grouping with parentheses
+
+LIMITATIONS:
+- Only searches public repositories
+- Rate limits may apply
+- Complex queries may take longer to execute
+- Maximum of 20 results per query
+
+TIPS:
+- Use specific file extensions to narrow results
+- Add repo: filters for more targeted searches
+- Use type:symbol to find function/method definitions
+- Use type:file to find relevant files
+- For more details on query syntax, visit: https://docs.sourcegraph.com/code_search/queries`
+)
+
+type SourcegraphParams struct {
+	Query   string `json:"query"`
+	Count   int    `json:"count,omitempty"`
+	Timeout int    `json:"timeout,omitempty"`
+}
+
+type SourcegraphPermissionsParams struct {
+	Query   string `json:"query"`
+	Count   int    `json:"count,omitempty"`
+	Timeout int    `json:"timeout,omitempty"`
+}
+
+type sourcegraphTool struct {
+	client *http.Client
+}
+
+func NewSourcegraphTool() BaseTool {
+	return &sourcegraphTool{
+		client: &http.Client{
+			Timeout: 30 * time.Second,
+		},
+	}
+}
+
+func (t *sourcegraphTool) Info() ToolInfo {
+	return ToolInfo{
+		Name:        SourcegraphToolName,
+		Description: sourcegraphToolDescription,
+		Parameters: map[string]any{
+			"query": map[string]any{
+				"type":        "string",
+				"description": "The Sourcegraph search query",
+			},
+			"count": map[string]any{
+				"type":        "number",
+				"description": "Optional number of results to return (default: 10, max: 20)",
+			},
+			"timeout": map[string]any{
+				"type":        "number",
+				"description": "Optional timeout in seconds (max 120)",
+			},
+		},
+		Required: []string{"query"},
+	}
+}
+
+func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+	var params SourcegraphParams
+	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
+		return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
+	}
+
+	if params.Query == "" {
+		return NewTextErrorResponse("Query parameter is required"), nil
+	}
+
+	// Set default count if not specified
+	if params.Count <= 0 {
+		params.Count = 10
+	} else if params.Count > 20 {
+		params.Count = 20 // Limit to 20 results
+	}
+
+	client := t.client
+	if params.Timeout > 0 {
+		maxTimeout := 120 // 2 minutes
+		if params.Timeout > maxTimeout {
+			params.Timeout = maxTimeout
+		}
+		client = &http.Client{
+			Timeout: time.Duration(params.Timeout) * time.Second,
+		}
+	}
+
+	// GraphQL query for Sourcegraph search
+	// Create a properly escaped JSON structure
+	type graphqlRequest struct {
+		Query     string `json:"query"`
+		Variables struct {
+			Query string `json:"query"`
+		} `json:"variables"`
+	}
+
+	request := graphqlRequest{
+		Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: standard ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
+	}
+	request.Variables.Query = params.Query
+
+	// Marshal to JSON to ensure proper escaping
+	graphqlQueryBytes, err := json.Marshal(request)
+	if err != nil {
+		return NewTextErrorResponse("Failed to create GraphQL request: " + err.Error()), nil
+	}
+	graphqlQuery := string(graphqlQueryBytes)
+
+	// Create request to Sourcegraph API
+	req, err := http.NewRequestWithContext(
+		ctx,
+		"POST",
+		"https://sourcegraph.com/.api/graphql",
+		bytes.NewBuffer([]byte(graphqlQuery)),
+	)
+	if err != nil {
+		return NewTextErrorResponse("Failed to create request: " + err.Error()), nil
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("User-Agent", "termai/1.0")
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return NewTextErrorResponse("Failed to execute request: " + err.Error()), nil
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		// log the error response
+		body, _ := io.ReadAll(resp.Body)
+		if len(body) > 0 {
+			return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
+		}
+
+		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
+	}
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
+	}
+
+	// Parse the GraphQL response
+	var result map[string]any
+	if err = json.Unmarshal(body, &result); err != nil {
+		return NewTextErrorResponse("Failed to parse response: " + err.Error()), nil
+	}
+
+	// Format the results in a readable way
+	formattedResults, err := formatSourcegraphResults(result)
+	if err != nil {
+		return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
+	}
+
+	return NewTextResponse(formattedResults), nil
+}
+
+func formatSourcegraphResults(result map[string]any) (string, error) {
+	var buffer strings.Builder
+
+	// Check for errors in the GraphQL response
+	if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
+		buffer.WriteString("## Sourcegraph API Error\n\n")
+		for _, err := range errors {
+			if errMap, ok := err.(map[string]any); ok {
+				if message, ok := errMap["message"].(string); ok {
+					buffer.WriteString(fmt.Sprintf("- %s\n", message))
+				}
+			}
+		}
+		return buffer.String(), nil
+	}
+
+	// Extract data from the response
+	data, ok := result["data"].(map[string]any)
+	if !ok {
+		return "", fmt.Errorf("invalid response format: missing data field")
+	}
+
+	search, ok := data["search"].(map[string]any)
+	if !ok {
+		return "", fmt.Errorf("invalid response format: missing search field")
+	}
+
+	searchResults, ok := search["results"].(map[string]any)
+	if !ok {
+		return "", fmt.Errorf("invalid response format: missing results field")
+	}
+
+	// Write search metadata
+	matchCount, _ := searchResults["matchCount"].(float64)
+	resultCount, _ := searchResults["resultCount"].(float64)
+	limitHit, _ := searchResults["limitHit"].(bool)
+
+	buffer.WriteString("# Sourcegraph Search Results\n\n")
+	buffer.WriteString(fmt.Sprintf("Found %d matches across %d results\n", int(matchCount), int(resultCount)))
+
+	if limitHit {
+		buffer.WriteString("(Result limit reached, try a more specific query)\n")
+	}
+
+	buffer.WriteString("\n")
+
+	// Process results
+	results, ok := searchResults["results"].([]any)
+	if !ok || len(results) == 0 {
+		buffer.WriteString("No results found. Try a different query.\n")
+		return buffer.String(), nil
+	}
+
+	// Limit to 10 results
+	maxResults := 10
+	if len(results) > maxResults {
+		results = results[:maxResults]
+	}
+
+	// Process each result
+	for i, res := range results {
+		fileMatch, ok := res.(map[string]any)
+		if !ok {
+			continue
+		}
+
+		// Skip non-FileMatch results
+		typeName, _ := fileMatch["__typename"].(string)
+		if typeName != "FileMatch" {
+			continue
+		}
+
+		// Extract repository and file information
+		repo, _ := fileMatch["repository"].(map[string]any)
+		file, _ := fileMatch["file"].(map[string]any)
+		lineMatches, _ := fileMatch["lineMatches"].([]any)
+
+		if repo == nil || file == nil {
+			continue
+		}
+
+		repoName, _ := repo["name"].(string)
+		filePath, _ := file["path"].(string)
+		fileURL, _ := file["url"].(string)
+		fileContent, _ := file["content"].(string)
+
+		buffer.WriteString(fmt.Sprintf("## Result %d: %s/%s\n\n", i+1, repoName, filePath))
+
+		if fileURL != "" {
+			buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL))
+		}
+
+		// Show line matches with context
+		if len(lineMatches) > 0 {
+			for _, lm := range lineMatches {
+				lineMatch, ok := lm.(map[string]any)
+				if !ok {
+					continue
+				}
+
+				lineNumber, _ := lineMatch["lineNumber"].(float64)
+				preview, _ := lineMatch["preview"].(string)
+
+				// Extract context from file content if available
+				if fileContent != "" {
+					lines := strings.Split(fileContent, "\n")
+
+					buffer.WriteString("```\n")
+
+					// Display context before the match (up to 10 lines)
+					contextBefore := 10
+					startLine := max(1, int(lineNumber)-contextBefore)
+
+					for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
+						if j >= 0 {
+							buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
+						}
+					}
+
+					// Display the matching line (highlighted)
+					buffer.WriteString(fmt.Sprintf("%d|  %s\n", int(lineNumber), preview))
+
+					// Display context after the match (up to 10 lines)
+					contextAfter := 10
+					endLine := int(lineNumber) + contextAfter
+
+					for j := int(lineNumber); j < endLine && j < len(lines); j++ {
+						if j < len(lines) {
+							buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
+						}
+					}
+
+					buffer.WriteString("```\n\n")
+				} else {
+					// If file content is not available, just show the preview
+					buffer.WriteString("```\n")
+					buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
+					buffer.WriteString("```\n\n")
+				}
+			}
+		}
+	}
+
+	return buffer.String(), nil
+}

internal/llm/tools/sourcegraph_test.go 🔗

@@ -0,0 +1,115 @@
+package tools
+
+import (
+	"context"
+	"encoding/json"
+	"testing"
+
+	"github.com/kujtimiihoxha/termai/internal/permission"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestSourcegraphTool_Info(t *testing.T) {
+	tool := NewSourcegraphTool()
+	info := tool.Info()
+
+	assert.Equal(t, SourcegraphToolName, info.Name)
+	assert.NotEmpty(t, info.Description)
+	assert.Contains(t, info.Parameters, "query")
+	assert.Contains(t, info.Parameters, "count")
+	assert.Contains(t, info.Parameters, "timeout")
+	assert.Contains(t, info.Required, "query")
+}
+
+func TestSourcegraphTool_Run(t *testing.T) {
+	// Setup a mock permission handler that always allows
+	origPermission := permission.Default
+	defer func() {
+		permission.Default = origPermission
+	}()
+	permission.Default = newMockPermissionService(true)
+
+	t.Run("handles missing query parameter", func(t *testing.T) {
+		tool := NewSourcegraphTool()
+		params := SourcegraphParams{
+			Query: "",
+		}
+
+		paramsJSON, err := json.Marshal(params)
+		require.NoError(t, err)
+
+		call := ToolCall{
+			Name:  SourcegraphToolName,
+			Input: string(paramsJSON),
+		}
+
+		response, err := tool.Run(context.Background(), call)
+		require.NoError(t, err)
+		assert.Contains(t, response.Content, "Query parameter is required")
+	})
+
+	t.Run("handles invalid parameters", func(t *testing.T) {
+		tool := NewSourcegraphTool()
+		call := ToolCall{
+			Name:  SourcegraphToolName,
+			Input: "invalid json",
+		}
+
+		response, err := tool.Run(context.Background(), call)
+		require.NoError(t, err)
+		assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
+	})
+
+	t.Run("handles permission denied", func(t *testing.T) {
+		permission.Default = newMockPermissionService(false)
+
+		tool := NewSourcegraphTool()
+		params := SourcegraphParams{
+			Query: "test query",
+		}
+
+		paramsJSON, err := json.Marshal(params)
+		require.NoError(t, err)
+
+		call := ToolCall{
+			Name:  SourcegraphToolName,
+			Input: string(paramsJSON),
+		}
+
+		response, err := tool.Run(context.Background(), call)
+		require.NoError(t, err)
+		assert.Contains(t, response.Content, "Permission denied")
+	})
+
+	t.Run("normalizes count parameter", func(t *testing.T) {
+		// Test cases for count normalization
+		testCases := []struct {
+			name          string
+			inputCount    int
+			expectedCount int
+		}{
+			{"negative count", -5, 10},    // Should use default (10)
+			{"zero count", 0, 10},         // Should use default (10)
+			{"valid count", 50, 50},       // Should keep as is
+			{"excessive count", 150, 100}, // Should cap at 100
+		}
+
+		for _, tc := range testCases {
+			t.Run(tc.name, func(t *testing.T) {
+				// Verify count normalization logic directly
+				assert.NotPanics(t, func() {
+					// Apply the same normalization logic as in the tool
+					normalizedCount := tc.inputCount
+					if normalizedCount <= 0 {
+						normalizedCount = 10
+					} else if normalizedCount > 100 {
+						normalizedCount = 100
+					}
+
+					assert.Equal(t, tc.expectedCount, normalizedCount)
+				})
+			})
+		}
+	})
+}