@@ -0,0 +1,401 @@
+package tools
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+)
+
+const (
+ SourcegraphToolName = "sourcegraph"
+ sourcegraphToolDescription = `Search code across public repositories using Sourcegraph's GraphQL API.
+
+WHEN TO USE THIS TOOL:
+- Use when you need to find code examples or implementations across public repositories
+- Helpful for researching how others have solved similar problems
+- Useful for discovering patterns and best practices in open source code
+
+HOW TO USE:
+- Provide a search query using Sourcegraph's query syntax
+- Optionally specify the number of results to return (default: 10)
+- Optionally set a timeout for the request
+
+QUERY SYNTAX:
+- Basic search: "fmt.Println" searches for exact matches
+- File filters: "file:.go fmt.Println" limits to Go files
+- Repository filters: "repo:^github\.com/golang/go$ fmt.Println" limits to specific repos
+- Language filters: "lang:go fmt.Println" limits to Go code
+- Boolean operators: "fmt.Println AND log.Fatal" for combined terms
+- Regular expressions: "fmt\.(Print|Printf|Println)" for pattern matching
+- Quoted strings: "\"exact phrase\"" for exact phrase matching
+- Exclude filters: "-file:test" or "-repo:forks" to exclude matches
+
+ADVANCED FILTERS:
+- Repository filters:
+ * "repo:name" - Match repositories with name containing "name"
+ * "repo:^github\.com/org/repo$" - Exact repository match
+ * "repo:org/repo@branch" - Search specific branch
+ * "repo:org/repo rev:branch" - Alternative branch syntax
+ * "-repo:name" - Exclude repositories
+ * "fork:yes" or "fork:only" - Include or only show forks
+ * "archived:yes" or "archived:only" - Include or only show archived repos
+ * "visibility:public" or "visibility:private" - Filter by visibility
+
+- File filters:
+ * "file:\.js$" - Files with .js extension
+ * "file:internal/" - Files in internal directory
+ * "-file:test" - Exclude test files
+ * "file:has.content(Copyright)" - Files containing "Copyright"
+ * "file:has.contributor([email protected])" - Files with specific contributor
+
+- Content filters:
+ * "content:\"exact string\"" - Search for exact string
+ * "-content:\"unwanted\"" - Exclude files with unwanted content
+ * "case:yes" - Case-sensitive search
+
+- Type filters:
+ * "type:symbol" - Search for symbols (functions, classes, etc.)
+ * "type:file" - Search file content only
+ * "type:path" - Search filenames only
+ * "type:diff" - Search code changes
+ * "type:commit" - Search commit messages
+
+- Commit/diff search:
+ * "after:\"1 month ago\"" - Commits after date
+ * "before:\"2023-01-01\"" - Commits before date
+ * "author:name" - Commits by author
+ * "message:\"fix bug\"" - Commits with message
+
+- Result selection:
+ * "select:repo" - Show only repository names
+ * "select:file" - Show only file paths
+ * "select:content" - Show only matching content
+ * "select:symbol" - Show only matching symbols
+
+- Result control:
+ * "count:100" - Return up to 100 results
+ * "count:all" - Return all results
+ * "timeout:30s" - Set search timeout
+
+EXAMPLES:
+- "file:.go context.WithTimeout" - Find Go code using context.WithTimeout
+- "lang:typescript useState type:symbol" - Find TypeScript React useState hooks
+- "repo:^github\.com/kubernetes/kubernetes$ pod list type:file" - Find Kubernetes files related to pod listing
+- "repo:sourcegraph/sourcegraph$ after:\"3 months ago\" type:diff database" - Recent changes to database code
+- "file:Dockerfile (alpine OR ubuntu) -content:alpine:latest" - Dockerfiles with specific base images
+- "repo:has.path(\.py) file:requirements.txt tensorflow" - Python projects using TensorFlow
+
+BOOLEAN OPERATORS:
+- "term1 AND term2" - Results containing both terms
+- "term1 OR term2" - Results containing either term
+- "term1 NOT term2" - Results with term1 but not term2
+- "term1 and (term2 or term3)" - Grouping with parentheses
+
+LIMITATIONS:
+- Only searches public repositories
+- Rate limits may apply
+- Complex queries may take longer to execute
+- Maximum of 20 results per query
+
+TIPS:
+- Use specific file extensions to narrow results
+- Add repo: filters for more targeted searches
+- Use type:symbol to find function/method definitions
+- Use type:file to find relevant files
+- For more details on query syntax, visit: https://docs.sourcegraph.com/code_search/queries`
+)
+
+type SourcegraphParams struct {
+ Query string `json:"query"`
+ Count int `json:"count,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+type SourcegraphPermissionsParams struct {
+ Query string `json:"query"`
+ Count int `json:"count,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+type sourcegraphTool struct {
+ client *http.Client
+}
+
+func NewSourcegraphTool() BaseTool {
+ return &sourcegraphTool{
+ client: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ }
+}
+
+func (t *sourcegraphTool) Info() ToolInfo {
+ return ToolInfo{
+ Name: SourcegraphToolName,
+ Description: sourcegraphToolDescription,
+ Parameters: map[string]any{
+ "query": map[string]any{
+ "type": "string",
+ "description": "The Sourcegraph search query",
+ },
+ "count": map[string]any{
+ "type": "number",
+ "description": "Optional number of results to return (default: 10, max: 20)",
+ },
+ "timeout": map[string]any{
+ "type": "number",
+ "description": "Optional timeout in seconds (max 120)",
+ },
+ },
+ Required: []string{"query"},
+ }
+}
+
+func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
+ var params SourcegraphParams
+ if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
+ return NewTextErrorResponse("Failed to parse sourcegraph parameters: " + err.Error()), nil
+ }
+
+ if params.Query == "" {
+ return NewTextErrorResponse("Query parameter is required"), nil
+ }
+
+ // Set default count if not specified
+ if params.Count <= 0 {
+ params.Count = 10
+ } else if params.Count > 20 {
+ params.Count = 20 // Limit to 20 results
+ }
+
+ client := t.client
+ if params.Timeout > 0 {
+ maxTimeout := 120 // 2 minutes
+ if params.Timeout > maxTimeout {
+ params.Timeout = maxTimeout
+ }
+ client = &http.Client{
+ Timeout: time.Duration(params.Timeout) * time.Second,
+ }
+ }
+
+ // GraphQL query for Sourcegraph search
+ // Create a properly escaped JSON structure
+ type graphqlRequest struct {
+ Query string `json:"query"`
+ Variables struct {
+ Query string `json:"query"`
+ } `json:"variables"`
+ }
+
+ request := graphqlRequest{
+ Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: standard ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
+ }
+ request.Variables.Query = params.Query
+
+ // Marshal to JSON to ensure proper escaping
+ graphqlQueryBytes, err := json.Marshal(request)
+ if err != nil {
+ return NewTextErrorResponse("Failed to create GraphQL request: " + err.Error()), nil
+ }
+ graphqlQuery := string(graphqlQueryBytes)
+
+ // Create request to Sourcegraph API
+ req, err := http.NewRequestWithContext(
+ ctx,
+ "POST",
+ "https://sourcegraph.com/.api/graphql",
+ bytes.NewBuffer([]byte(graphqlQuery)),
+ )
+ if err != nil {
+ return NewTextErrorResponse("Failed to create request: " + err.Error()), nil
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", "termai/1.0")
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return NewTextErrorResponse("Failed to execute request: " + err.Error()), nil
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ // log the error response
+ body, _ := io.ReadAll(resp.Body)
+ if len(body) > 0 {
+ return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d, response: %s", resp.StatusCode, string(body))), nil
+ }
+
+ return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
+ }
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
+ }
+
+ // Parse the GraphQL response
+ var result map[string]any
+ if err = json.Unmarshal(body, &result); err != nil {
+ return NewTextErrorResponse("Failed to parse response: " + err.Error()), nil
+ }
+
+ // Format the results in a readable way
+ formattedResults, err := formatSourcegraphResults(result)
+ if err != nil {
+ return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
+ }
+
+ return NewTextResponse(formattedResults), nil
+}
+
+func formatSourcegraphResults(result map[string]any) (string, error) {
+ var buffer strings.Builder
+
+ // Check for errors in the GraphQL response
+ if errors, ok := result["errors"].([]any); ok && len(errors) > 0 {
+ buffer.WriteString("## Sourcegraph API Error\n\n")
+ for _, err := range errors {
+ if errMap, ok := err.(map[string]any); ok {
+ if message, ok := errMap["message"].(string); ok {
+ buffer.WriteString(fmt.Sprintf("- %s\n", message))
+ }
+ }
+ }
+ return buffer.String(), nil
+ }
+
+ // Extract data from the response
+ data, ok := result["data"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("invalid response format: missing data field")
+ }
+
+ search, ok := data["search"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("invalid response format: missing search field")
+ }
+
+ searchResults, ok := search["results"].(map[string]any)
+ if !ok {
+ return "", fmt.Errorf("invalid response format: missing results field")
+ }
+
+ // Write search metadata
+ matchCount, _ := searchResults["matchCount"].(float64)
+ resultCount, _ := searchResults["resultCount"].(float64)
+ limitHit, _ := searchResults["limitHit"].(bool)
+
+ buffer.WriteString("# Sourcegraph Search Results\n\n")
+ buffer.WriteString(fmt.Sprintf("Found %d matches across %d results\n", int(matchCount), int(resultCount)))
+
+ if limitHit {
+ buffer.WriteString("(Result limit reached, try a more specific query)\n")
+ }
+
+ buffer.WriteString("\n")
+
+ // Process results
+ results, ok := searchResults["results"].([]any)
+ if !ok || len(results) == 0 {
+ buffer.WriteString("No results found. Try a different query.\n")
+ return buffer.String(), nil
+ }
+
+ // Limit to 10 results
+ maxResults := 10
+ if len(results) > maxResults {
+ results = results[:maxResults]
+ }
+
+ // Process each result
+ for i, res := range results {
+ fileMatch, ok := res.(map[string]any)
+ if !ok {
+ continue
+ }
+
+ // Skip non-FileMatch results
+ typeName, _ := fileMatch["__typename"].(string)
+ if typeName != "FileMatch" {
+ continue
+ }
+
+ // Extract repository and file information
+ repo, _ := fileMatch["repository"].(map[string]any)
+ file, _ := fileMatch["file"].(map[string]any)
+ lineMatches, _ := fileMatch["lineMatches"].([]any)
+
+ if repo == nil || file == nil {
+ continue
+ }
+
+ repoName, _ := repo["name"].(string)
+ filePath, _ := file["path"].(string)
+ fileURL, _ := file["url"].(string)
+ fileContent, _ := file["content"].(string)
+
+ buffer.WriteString(fmt.Sprintf("## Result %d: %s/%s\n\n", i+1, repoName, filePath))
+
+ if fileURL != "" {
+ buffer.WriteString(fmt.Sprintf("URL: %s\n\n", fileURL))
+ }
+
+ // Show line matches with context
+ if len(lineMatches) > 0 {
+ for _, lm := range lineMatches {
+ lineMatch, ok := lm.(map[string]any)
+ if !ok {
+ continue
+ }
+
+ lineNumber, _ := lineMatch["lineNumber"].(float64)
+ preview, _ := lineMatch["preview"].(string)
+
+ // Extract context from file content if available
+ if fileContent != "" {
+ lines := strings.Split(fileContent, "\n")
+
+ buffer.WriteString("```\n")
+
+ // Display context before the match (up to 10 lines)
+ contextBefore := 10
+ startLine := max(1, int(lineNumber)-contextBefore)
+
+ for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
+ if j >= 0 {
+ buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
+ }
+ }
+
+ // Display the matching line (highlighted)
+ buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
+
+ // Display context after the match (up to 10 lines)
+ contextAfter := 10
+ endLine := int(lineNumber) + contextAfter
+
+ for j := int(lineNumber); j < endLine && j < len(lines); j++ {
+ if j < len(lines) {
+ buffer.WriteString(fmt.Sprintf("%d| %s\n", j+1, lines[j]))
+ }
+ }
+
+ buffer.WriteString("```\n\n")
+ } else {
+ // If file content is not available, just show the preview
+ buffer.WriteString("```\n")
+ buffer.WriteString(fmt.Sprintf("%d| %s\n", int(lineNumber), preview))
+ buffer.WriteString("```\n\n")
+ }
+ }
+ }
+ }
+
+ return buffer.String(), nil
+}
@@ -0,0 +1,115 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/kujtimiihoxha/termai/internal/permission"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestSourcegraphTool_Info(t *testing.T) {
+ tool := NewSourcegraphTool()
+ info := tool.Info()
+
+ assert.Equal(t, SourcegraphToolName, info.Name)
+ assert.NotEmpty(t, info.Description)
+ assert.Contains(t, info.Parameters, "query")
+ assert.Contains(t, info.Parameters, "count")
+ assert.Contains(t, info.Parameters, "timeout")
+ assert.Contains(t, info.Required, "query")
+}
+
+func TestSourcegraphTool_Run(t *testing.T) {
+ // Setup a mock permission handler that always allows
+ origPermission := permission.Default
+ defer func() {
+ permission.Default = origPermission
+ }()
+ permission.Default = newMockPermissionService(true)
+
+ t.Run("handles missing query parameter", func(t *testing.T) {
+ tool := NewSourcegraphTool()
+ params := SourcegraphParams{
+ Query: "",
+ }
+
+ paramsJSON, err := json.Marshal(params)
+ require.NoError(t, err)
+
+ call := ToolCall{
+ Name: SourcegraphToolName,
+ Input: string(paramsJSON),
+ }
+
+ response, err := tool.Run(context.Background(), call)
+ require.NoError(t, err)
+ assert.Contains(t, response.Content, "Query parameter is required")
+ })
+
+ t.Run("handles invalid parameters", func(t *testing.T) {
+ tool := NewSourcegraphTool()
+ call := ToolCall{
+ Name: SourcegraphToolName,
+ Input: "invalid json",
+ }
+
+ response, err := tool.Run(context.Background(), call)
+ require.NoError(t, err)
+ assert.Contains(t, response.Content, "Failed to parse sourcegraph parameters")
+ })
+
+ t.Run("handles permission denied", func(t *testing.T) {
+ permission.Default = newMockPermissionService(false)
+
+ tool := NewSourcegraphTool()
+ params := SourcegraphParams{
+ Query: "test query",
+ }
+
+ paramsJSON, err := json.Marshal(params)
+ require.NoError(t, err)
+
+ call := ToolCall{
+ Name: SourcegraphToolName,
+ Input: string(paramsJSON),
+ }
+
+ response, err := tool.Run(context.Background(), call)
+ require.NoError(t, err)
+ assert.Contains(t, response.Content, "Permission denied")
+ })
+
+ t.Run("normalizes count parameter", func(t *testing.T) {
+ // Test cases for count normalization
+ testCases := []struct {
+ name string
+ inputCount int
+ expectedCount int
+ }{
+ {"negative count", -5, 10}, // Should use default (10)
+ {"zero count", 0, 10}, // Should use default (10)
+ {"valid count", 50, 50}, // Should keep as is
+ {"excessive count", 150, 100}, // Should cap at 100
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Verify count normalization logic directly
+ assert.NotPanics(t, func() {
+ // Apply the same normalization logic as in the tool
+ normalizedCount := tc.inputCount
+ if normalizedCount <= 0 {
+ normalizedCount = 10
+ } else if normalizedCount > 100 {
+ normalizedCount = 100
+ }
+
+ assert.Equal(t, tc.expectedCount, normalizedCount)
+ })
+ })
+ }
+ })
+}