diagnostics.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"maps"
  8	"sort"
  9	"strings"
 10	"time"
 11
 12	"github.com/kujtimiihoxha/termai/internal/lsp"
 13	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 14)
 15
 16type diagnosticsTool struct {
 17	lspClients map[string]*lsp.Client
 18}
 19
 20const (
 21	DiagnosticsToolName = "diagnostics"
 22)
 23
 24type DiagnosticsParams struct {
 25	FilePath string `json:"file_path"`
 26}
 27
 28func (b *diagnosticsTool) Info() ToolInfo {
 29	return ToolInfo{
 30		Name:        DiagnosticsToolName,
 31		Description: "Get diagnostics for a file and/or project.",
 32		Parameters: map[string]any{
 33			"file_path": map[string]any{
 34				"type":        "string",
 35				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 36			},
 37		},
 38		Required: []string{},
 39	}
 40}
 41
 42func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 43	var params DiagnosticsParams
 44	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 45		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 46	}
 47
 48	lsps := b.lspClients
 49
 50	if len(lsps) == 0 {
 51		return NewTextErrorResponse("no LSP clients available"), nil
 52	}
 53
 54	if params.FilePath != "" {
 55		notifyLspOpenFile(ctx, params.FilePath, lsps)
 56		waitForLspDiagnostics(ctx, params.FilePath, lsps)
 57	}
 58
 59	output := appendDiagnostics(params.FilePath, lsps)
 60
 61	return NewTextResponse(output), nil
 62}
 63
 64func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 65	for _, client := range lsps {
 66		// Open the file
 67		err := client.OpenFile(ctx, filePath)
 68		if err != nil {
 69			// If there's an error opening the file, continue to the next client
 70			continue
 71		}
 72	}
 73}
 74
 75// waitForLspDiagnostics opens a file in LSP clients and waits for diagnostics to be published
 76func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 77	if len(lsps) == 0 {
 78		return
 79	}
 80
 81	// Create a channel to receive diagnostic notifications
 82	diagChan := make(chan struct{}, 1)
 83
 84	// Register a temporary diagnostic handler for each client
 85	for _, client := range lsps {
 86		// Store the original diagnostics map to detect changes
 87		originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
 88		maps.Copy(originalDiags, client.GetDiagnostics())
 89
 90		// Create a notification handler that will signal when diagnostics are received
 91		handler := func(params json.RawMessage) {
 92			lsp.HandleDiagnostics(client, params)
 93			var diagParams protocol.PublishDiagnosticsParams
 94			if err := json.Unmarshal(params, &diagParams); err != nil {
 95				return
 96			}
 97
 98			// If this is for our file or we've received any new diagnostics, signal completion
 99			if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
100				select {
101				case diagChan <- struct{}{}:
102					// Signal sent
103				default:
104					// Channel already has a value, no need to send again
105				}
106			}
107		}
108
109		// Register our temporary handler
110		client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
111
112		// Notify change if the file is already open
113		if client.IsFileOpen(filePath) {
114			err := client.NotifyChange(ctx, filePath)
115			if err != nil {
116				continue
117			}
118		} else {
119			// Open the file if it's not already open
120			err := client.OpenFile(ctx, filePath)
121			if err != nil {
122				continue
123			}
124		}
125	}
126
127	// Wait for diagnostics with a reasonable timeout
128	select {
129	case <-diagChan:
130		// Diagnostics received
131	case <-time.After(5 * time.Second):
132		// Timeout after 5 seconds - this is a fallback in case no diagnostics are published
133	case <-ctx.Done():
134		// Context cancelled
135	}
136
137	// Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
138	// replaces any existing handler, and we'll be replaced by the original handler when
139	// the LSP client is reinitialized or when a new handler is registered.
140}
141
142// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
143func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
144	for uri, diags := range current {
145		origDiags, exists := original[uri]
146		if !exists || len(diags) != len(origDiags) {
147			return true
148		}
149	}
150	return false
151}
152
153func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
154	fileDiagnostics := []string{}
155	projectDiagnostics := []string{}
156
157	// Enhanced format function that includes more diagnostic information
158	formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
159		// Base components
160		severity := "Info"
161		switch diagnostic.Severity {
162		case protocol.SeverityError:
163			severity = "Error"
164		case protocol.SeverityWarning:
165			severity = "Warn"
166		case protocol.SeverityHint:
167			severity = "Hint"
168		}
169
170		// Location information
171		location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
172
173		// Source information (LSP name)
174		sourceInfo := ""
175		if diagnostic.Source != "" {
176			sourceInfo = diagnostic.Source
177		} else if source != "" {
178			sourceInfo = source
179		}
180
181		// Code information
182		codeInfo := ""
183		if diagnostic.Code != nil {
184			codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
185		}
186
187		// Tags information
188		tagsInfo := ""
189		if len(diagnostic.Tags) > 0 {
190			tags := []string{}
191			for _, tag := range diagnostic.Tags {
192				switch tag {
193				case protocol.Unnecessary:
194					tags = append(tags, "unnecessary")
195				case protocol.Deprecated:
196					tags = append(tags, "deprecated")
197				}
198			}
199			if len(tags) > 0 {
200				tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
201			}
202		}
203
204		// Assemble the full diagnostic message
205		return fmt.Sprintf("%s: %s [%s]%s%s %s",
206			severity,
207			location,
208			sourceInfo,
209			codeInfo,
210			tagsInfo,
211			diagnostic.Message)
212	}
213
214	for lspName, client := range lsps {
215		diagnostics := client.GetDiagnostics()
216		if len(diagnostics) > 0 {
217			for location, diags := range diagnostics {
218				isCurrentFile := location.Path() == filePath
219
220				// Group diagnostics by severity for better organization
221				for _, diag := range diags {
222					formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
223
224					if isCurrentFile {
225						fileDiagnostics = append(fileDiagnostics, formattedDiag)
226					} else {
227						projectDiagnostics = append(projectDiagnostics, formattedDiag)
228					}
229				}
230			}
231		}
232	}
233
234	// Sort diagnostics by severity (errors first) and then by location
235	sort.Slice(fileDiagnostics, func(i, j int) bool {
236		iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
237		jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
238		if iIsError != jIsError {
239			return iIsError // Errors come first
240		}
241		return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
242	})
243
244	sort.Slice(projectDiagnostics, func(i, j int) bool {
245		iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
246		jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
247		if iIsError != jIsError {
248			return iIsError
249		}
250		return projectDiagnostics[i] < projectDiagnostics[j]
251	})
252
253	output := ""
254
255	if len(fileDiagnostics) > 0 {
256		output += "\n<file_diagnostics>\n"
257		if len(fileDiagnostics) > 10 {
258			output += strings.Join(fileDiagnostics[:10], "\n")
259			output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
260		} else {
261			output += strings.Join(fileDiagnostics, "\n")
262		}
263		output += "\n</file_diagnostics>\n"
264	}
265
266	if len(projectDiagnostics) > 0 {
267		output += "\n<project_diagnostics>\n"
268		if len(projectDiagnostics) > 10 {
269			output += strings.Join(projectDiagnostics[:10], "\n")
270			output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
271		} else {
272			output += strings.Join(projectDiagnostics, "\n")
273		}
274		output += "\n</project_diagnostics>\n"
275	}
276
277	// Add summary counts
278	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
279		fileErrors := countSeverity(fileDiagnostics, "Error")
280		fileWarnings := countSeverity(fileDiagnostics, "Warn")
281		projectErrors := countSeverity(projectDiagnostics, "Error")
282		projectWarnings := countSeverity(projectDiagnostics, "Warn")
283
284		output += "\n<diagnostic_summary>\n"
285		output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
286		output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
287		output += "</diagnostic_summary>\n"
288	}
289
290	return output
291}
292
// countSeverity returns how many formatted diagnostic lines begin with the
// given severity label (e.g. "Error" or "Warn"). Matching is a plain,
// case-sensitive prefix test on each line.
func countSeverity(diagnostics []string, severity string) int {
	matches := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		matches++
	}
	return matches
}
303
304func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
305	return &diagnosticsTool{
306		lspClients,
307	}
308}