Improve Sourcegraph tool with context window and fix diagnostics

Kujtim Hoxha and termai created

- Add context_window parameter to control code context display
- Fix LSP diagnostics notification handling with proper async waiting
- Switch to keyword search pattern for better results
- Add Sourcegraph tool to task agent

🤖 Generated with termai
Co-Authored-By: termai <noreply@termai.io>

Change summary

cmd/lsp/main.go                   | 12 ++++++
internal/llm/agent/task.go        |  1 
internal/llm/tools/diagnostics.go | 62 ++++++++++++++++++++++++++++++--
internal/llm/tools/sourcegraph.go | 32 ++++++++--------
4 files changed, 87 insertions(+), 20 deletions(-)

Detailed changes

cmd/lsp/main.go 🔗

@@ -1,4 +1,16 @@
 package main
 
+import (
+	"context"
+	"fmt"
+
+	"github.com/kujtimiihoxha/termai/internal/llm/tools"
+)
+
 func main() {
+	t := tools.NewSourcegraphTool()
+	r, _ := t.Run(context.Background(), tools.ToolCall{
+		Input: `{"query": "context.WithCancel lang:go"}`,
+	})
+	fmt.Println(r.Content)
 }

internal/llm/agent/task.go 🔗

@@ -34,6 +34,7 @@ func NewTaskAgent(app *app.App) (Agent, error) {
 				tools.NewGlobTool(),
 				tools.NewGrepTool(),
 				tools.NewLsTool(),
+				tools.NewSourcegraphTool(),
 				tools.NewViewTool(app.LSPClients),
 			},
 			model:          model,

internal/llm/tools/diagnostics.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"maps"
 	"sort"
 	"strings"
 	"time"
@@ -50,7 +51,7 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 		return NewTextErrorResponse("no LSP clients available"), nil
 	}
 
-	if params.FilePath == "" {
+	if params.FilePath != "" {
 		notifyLspOpenFile(ctx, params.FilePath, lsps)
 	}
 
@@ -60,15 +61,68 @@ func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 }
 
 func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
+	// Create a channel to receive diagnostic notifications
+	diagChan := make(chan struct{}, 1)
+
+	// Register a temporary diagnostic handler for each client
 	for _, client := range lsps {
+		// Store the original diagnostics map to detect changes
+		originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
+		maps.Copy(originalDiags, client.GetDiagnostics())
+
+		// Create a notification handler that will signal when diagnostics are received
+		handler := func(params json.RawMessage) {
+			var diagParams protocol.PublishDiagnosticsParams
+			if err := json.Unmarshal(params, &diagParams); err != nil {
+				return
+			}
+
+			// If this is for our file or we've received any new diagnostics, signal completion
+			if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
+				select {
+				case diagChan <- struct{}{}:
+					// Signal sent
+				default:
+					// Channel already has a value, no need to send again
+				}
+			}
+		}
+
+		// Register our temporary handler
+		client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
+
+		// Open the file
 		err := client.OpenFile(ctx, filePath)
 		if err != nil {
-			// Wait for the file to be opened and diagnostics to be received
-			// TODO: see if we can do this in a more efficient way
-			time.Sleep(3 * time.Second)
+			// If there's an error opening the file, continue to the next client
+			continue
 		}
+	}
+
+	// Wait for diagnostics with a reasonable timeout
+	select {
+	case <-diagChan:
+		// Diagnostics received
+	case <-time.After(5 * time.Second):
+		// Timeout after 2 seconds - this is a fallback in case no diagnostics are published
+	case <-ctx.Done():
+		// Context cancelled
+	}
 
+	// Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
+	// replaces any existing handler, and we'll be replaced by the original handler when
+	// the LSP client is reinitialized or when a new handler is registered.
+}
+
+// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
+func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
+	for uri, diags := range current {
+		origDiags, exists := original[uri]
+		if !exists || len(diags) != len(origDiags) {
+			return true
+		}
 	}
+	return false
 }
 
 func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {

internal/llm/tools/sourcegraph.go 🔗

@@ -111,15 +111,10 @@ TIPS:
 )
 
 type SourcegraphParams struct {
-	Query   string `json:"query"`
-	Count   int    `json:"count,omitempty"`
-	Timeout int    `json:"timeout,omitempty"`
-}
-
-type SourcegraphPermissionsParams struct {
-	Query   string `json:"query"`
-	Count   int    `json:"count,omitempty"`
-	Timeout int    `json:"timeout,omitempty"`
+	Query         string `json:"query"`
+	Count         int    `json:"count,omitempty"`
+	ContextWindow int    `json:"context_window,omitempty"`
+	Timeout       int    `json:"timeout,omitempty"`
 }
 
 type sourcegraphTool struct {
@@ -147,6 +142,10 @@ func (t *sourcegraphTool) Info() ToolInfo {
 				"type":        "number",
 				"description": "Optional number of results to return (default: 10, max: 20)",
 			},
+			"context_window": map[string]any{
+				"type":        "number",
+				"description": "The context around the match to return (default: 10 lines)",
+			},
 			"timeout": map[string]any{
 				"type":        "number",
 				"description": "Optional timeout in seconds (max 120)",
@@ -173,6 +172,9 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 		params.Count = 20 // Limit to 20 results
 	}
 
+	if params.ContextWindow <= 0 {
+		params.ContextWindow = 10 // Default context window
+	}
 	client := t.client
 	if params.Timeout > 0 {
 		maxTimeout := 120 // 2 minutes
@@ -194,7 +196,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	}
 
 	request := graphqlRequest{
-		Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: standard ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
+		Query: "query Search($query: String!) { search(query: $query, version: V2, patternType: keyword ) { results { matchCount, limitHit, resultCount, approximateResultCount, missing { name }, timedout { name }, indexUnavailable, results { __typename, ... on FileMatch { repository { name }, file { path, url, content }, lineMatches { preview, lineNumber, offsetAndLengths } } } } } }",
 	}
 	request.Variables.Query = params.Query
 
@@ -246,7 +248,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	}
 
 	// Format the results in a readable way
-	formattedResults, err := formatSourcegraphResults(result)
+	formattedResults, err := formatSourcegraphResults(result, params.ContextWindow)
 	if err != nil {
 		return NewTextErrorResponse("Failed to format results: " + err.Error()), nil
 	}
@@ -254,7 +256,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	return NewTextResponse(formattedResults), nil
 }
 
-func formatSourcegraphResults(result map[string]any) (string, error) {
+func formatSourcegraphResults(result map[string]any, contextWindow int) (string, error) {
 	var buffer strings.Builder
 
 	// Check for errors in the GraphQL response
@@ -364,8 +366,7 @@ func formatSourcegraphResults(result map[string]any) (string, error) {
 					buffer.WriteString("```\n")
 
 					// Display context before the match (up to 10 lines)
-					contextBefore := 10
-					startLine := max(1, int(lineNumber)-contextBefore)
+					startLine := max(1, int(lineNumber)-contextWindow)
 
 					for j := startLine - 1; j < int(lineNumber)-1 && j < len(lines); j++ {
 						if j >= 0 {
@@ -377,8 +378,7 @@ func formatSourcegraphResults(result map[string]any) (string, error) {
 					buffer.WriteString(fmt.Sprintf("%d|  %s\n", int(lineNumber), preview))
 
 					// Display context after the match (up to 10 lines)
-					contextAfter := 10
-					endLine := int(lineNumber) + contextAfter
+					endLine := int(lineNumber) + contextWindow
 
 					for j := int(lineNumber); j < endLine && j < len(lines); j++ {
 						if j < len(lines) {