diagnostics.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"maps"
  8	"sort"
  9	"strings"
 10	"time"
 11
 12	"github.com/kujtimiihoxha/termai/internal/lsp"
 13	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 14)
 15
// diagnosticsTool reports LSP diagnostics gathered from a set of LSP clients.
type diagnosticsTool struct {
	// lspClients maps a client name to its client. The name is used as the
	// fallback "source" of a diagnostic when the server does not provide one.
	lspClients map[string]*lsp.Client
}
 19
 20const (
 21	DiagnosticsToolName = "diagnostics"
 22)
 23
// DiagnosticsParams is the JSON input accepted by the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath is the file to report diagnostics for. When empty, no file is
	// opened and every cached diagnostic is reported as a project diagnostic.
	FilePath string `json:"file_path"`
}
 27
 28func (b *diagnosticsTool) Info() ToolInfo {
 29	return ToolInfo{
 30		Name:        DiagnosticsToolName,
 31		Description: "Get diagnostics for a file and/or project.",
 32		Parameters: map[string]any{
 33			"file_path": map[string]any{
 34				"type":        "string",
 35				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 36			},
 37		},
 38		Required: []string{},
 39	}
 40}
 41
 42func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 43	var params DiagnosticsParams
 44	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 45		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 46	}
 47
 48	lsps := b.lspClients
 49
 50	if len(lsps) == 0 {
 51		return NewTextErrorResponse("no LSP clients available"), nil
 52	}
 53
 54	if params.FilePath != "" {
 55		notifyLspOpenFile(ctx, params.FilePath, lsps)
 56	}
 57
 58	output := appendDiagnostics(params.FilePath, lsps)
 59
 60	return NewTextResponse(output), nil
 61}
 62
 63func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 64	// Create a channel to receive diagnostic notifications
 65	diagChan := make(chan struct{}, 1)
 66
 67	// Register a temporary diagnostic handler for each client
 68	for _, client := range lsps {
 69		// Store the original diagnostics map to detect changes
 70		originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
 71		maps.Copy(originalDiags, client.GetDiagnostics())
 72
 73		// Create a notification handler that will signal when diagnostics are received
 74		handler := func(params json.RawMessage) {
 75			lsp.HandleDiagnostics(client, params)
 76			var diagParams protocol.PublishDiagnosticsParams
 77			if err := json.Unmarshal(params, &diagParams); err != nil {
 78				return
 79			}
 80
 81			// If this is for our file or we've received any new diagnostics, signal completion
 82			if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
 83				select {
 84				case diagChan <- struct{}{}:
 85					// Signal sent
 86				default:
 87					// Channel already has a value, no need to send again
 88				}
 89			}
 90		}
 91
 92		// Register our temporary handler
 93		client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
 94
 95		// Open the file
 96		err := client.OpenFile(ctx, filePath)
 97		if err != nil {
 98			// If there's an error opening the file, continue to the next client
 99			continue
100		}
101	}
102
103	// Wait for diagnostics with a reasonable timeout
104	select {
105	case <-diagChan:
106		// Diagnostics received
107	case <-time.After(10 * time.Second):
108		// Timeout after 5 seconds - this is a fallback in case no diagnostics are published
109	case <-ctx.Done():
110		// Context cancelled
111	}
112
113	// Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
114	// replaces any existing handler, and we'll be replaced by the original handler when
115	// the LSP client is reinitialized or when a new handler is registered.
116}
117
118// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
119func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
120	for uri, diags := range current {
121		origDiags, exists := original[uri]
122		if !exists || len(diags) != len(origDiags) {
123			return true
124		}
125	}
126	return false
127}
128
129func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
130	fileDiagnostics := []string{}
131	projectDiagnostics := []string{}
132
133	// Enhanced format function that includes more diagnostic information
134	formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
135		// Base components
136		severity := "Info"
137		switch diagnostic.Severity {
138		case protocol.SeverityError:
139			severity = "Error"
140		case protocol.SeverityWarning:
141			severity = "Warn"
142		case protocol.SeverityHint:
143			severity = "Hint"
144		}
145
146		// Location information
147		location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
148
149		// Source information (LSP name)
150		sourceInfo := ""
151		if diagnostic.Source != "" {
152			sourceInfo = diagnostic.Source
153		} else if source != "" {
154			sourceInfo = source
155		}
156
157		// Code information
158		codeInfo := ""
159		if diagnostic.Code != nil {
160			codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
161		}
162
163		// Tags information
164		tagsInfo := ""
165		if len(diagnostic.Tags) > 0 {
166			tags := []string{}
167			for _, tag := range diagnostic.Tags {
168				switch tag {
169				case protocol.Unnecessary:
170					tags = append(tags, "unnecessary")
171				case protocol.Deprecated:
172					tags = append(tags, "deprecated")
173				}
174			}
175			if len(tags) > 0 {
176				tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
177			}
178		}
179
180		// Assemble the full diagnostic message
181		return fmt.Sprintf("%s: %s [%s]%s%s %s",
182			severity,
183			location,
184			sourceInfo,
185			codeInfo,
186			tagsInfo,
187			diagnostic.Message)
188	}
189
190	for lspName, client := range lsps {
191		diagnostics := client.GetDiagnostics()
192		if len(diagnostics) > 0 {
193			for location, diags := range diagnostics {
194				isCurrentFile := location.Path() == filePath
195
196				// Group diagnostics by severity for better organization
197				for _, diag := range diags {
198					formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
199
200					if isCurrentFile {
201						fileDiagnostics = append(fileDiagnostics, formattedDiag)
202					} else {
203						projectDiagnostics = append(projectDiagnostics, formattedDiag)
204					}
205				}
206			}
207		}
208	}
209
210	// Sort diagnostics by severity (errors first) and then by location
211	sort.Slice(fileDiagnostics, func(i, j int) bool {
212		iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
213		jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
214		if iIsError != jIsError {
215			return iIsError // Errors come first
216		}
217		return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
218	})
219
220	sort.Slice(projectDiagnostics, func(i, j int) bool {
221		iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
222		jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
223		if iIsError != jIsError {
224			return iIsError
225		}
226		return projectDiagnostics[i] < projectDiagnostics[j]
227	})
228
229	output := ""
230
231	if len(fileDiagnostics) > 0 {
232		output += "\n<file_diagnostics>\n"
233		if len(fileDiagnostics) > 10 {
234			output += strings.Join(fileDiagnostics[:10], "\n")
235			output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
236		} else {
237			output += strings.Join(fileDiagnostics, "\n")
238		}
239		output += "\n</file_diagnostics>\n"
240	}
241
242	if len(projectDiagnostics) > 0 {
243		output += "\n<project_diagnostics>\n"
244		if len(projectDiagnostics) > 10 {
245			output += strings.Join(projectDiagnostics[:10], "\n")
246			output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
247		} else {
248			output += strings.Join(projectDiagnostics, "\n")
249		}
250		output += "\n</project_diagnostics>\n"
251	}
252
253	// Add summary counts
254	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
255		fileErrors := countSeverity(fileDiagnostics, "Error")
256		fileWarnings := countSeverity(fileDiagnostics, "Warn")
257		projectErrors := countSeverity(projectDiagnostics, "Error")
258		projectWarnings := countSeverity(projectDiagnostics, "Warn")
259
260		output += "\n<diagnostic_summary>\n"
261		output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
262		output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
263		output += "</diagnostic_summary>\n"
264	}
265
266	return output
267}
268
// countSeverity returns how many formatted diagnostics begin with the given
// severity label, such as "Error" or "Warn".
func countSeverity(diagnostics []string, severity string) int {
	matches := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		matches++
	}
	return matches
}
279
280func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
281	return &diagnosticsTool{
282		lspClients,
283	}
284}