diagnostics.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"sort"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/ai"
 13	"github.com/charmbracelet/crush/internal/lsp"
 14	"github.com/charmbracelet/crush/internal/lsp/protocol"
 15)
 16
 17type DiagnosticsParams struct {
 18	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave w empty for project diagnostics)"`
 19}
 20
 21const (
 22	DiagnosticsToolName = "diagnostics"
 23)
 24
 25func NewDiagnosticsTool(lsps map[string]*lsp.Client) ai.AgentTool {
 26	return ai.NewTypedToolFunc(
 27		DiagnosticsToolName,
 28		`Get diagnostics for a file and/or project.
 29WHEN TO USE THIS TOOL:
 30- Use when you need to check for errors or warnings in your code
 31- Helpful for debugging and ensuring code quality
 32- Good for getting a quick overview of issues in a file or project
 33HOW TO USE:
 34- Provide a path to a file to get diagnostics for that file
 35- Leave the path empty to get diagnostics for the entire project
 36- Results are displayed in a structured format with severity levels
 37FEATURES:
 38- Displays errors, warnings, and hints
 39- Groups diagnostics by severity
 40- Provides detailed information about each diagnostic
 41LIMITATIONS:
 42- Results are limited to the diagnostics provided by the LSP clients
 43- May not cover all possible issues in the code
 44- Does not provide suggestions for fixing issues
 45TIPS:
 46- Use in conjunction with other tools for a comprehensive code review
 47- Combine with the LSP client for real-time diagnostics`,
 48
 49		func(ctx context.Context, params DiagnosticsParams, call ai.ToolCall) (ai.ToolResponse, error) {
 50			if len(lsps) == 0 {
 51				return ai.NewTextErrorResponse("no LSP clients available"), nil
 52			}
 53
 54			if params.FilePath != "" {
 55				notifyLspOpenFile(ctx, params.FilePath, lsps)
 56				waitForLspDiagnostics(ctx, params.FilePath, lsps)
 57			}
 58
 59			output := getDiagnostics(params.FilePath, lsps)
 60
 61			return ai.NewTextResponse(output), nil
 62		},
 63	)
 64}
 65
 66func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 67	for _, client := range lsps {
 68		err := client.OpenFile(ctx, filePath)
 69		if err != nil {
 70			continue
 71		}
 72	}
 73}
 74
 75func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 76	if len(lsps) == 0 {
 77		return
 78	}
 79
 80	diagChan := make(chan struct{}, 1)
 81
 82	for _, client := range lsps {
 83		originalDiags := client.GetDiagnostics()
 84
 85		handler := func(params json.RawMessage) {
 86			lsp.HandleDiagnostics(client, params)
 87			var diagParams protocol.PublishDiagnosticsParams
 88			if err := json.Unmarshal(params, &diagParams); err != nil {
 89				return
 90			}
 91
 92			path, err := diagParams.URI.Path()
 93			if err != nil {
 94				slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
 95				return
 96			}
 97
 98			if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
 99				select {
100				case diagChan <- struct{}{}:
101				default:
102				}
103			}
104		}
105
106		client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
107
108		if client.IsFileOpen(filePath) {
109			err := client.NotifyChange(ctx, filePath)
110			if err != nil {
111				continue
112			}
113		} else {
114			err := client.OpenFile(ctx, filePath)
115			if err != nil {
116				continue
117			}
118		}
119	}
120
121	select {
122	case <-diagChan:
123	case <-time.After(5 * time.Second):
124	case <-ctx.Done():
125	}
126}
127
128func hasDiagnosticsChanged(current, original map[protocol.DocumentURI][]protocol.Diagnostic) bool {
129	for uri, diags := range current {
130		origDiags, exists := original[uri]
131		if !exists || len(diags) != len(origDiags) {
132			return true
133		}
134	}
135	return false
136}
137
138func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
139	fileDiagnostics := []string{}
140	projectDiagnostics := []string{}
141
142	formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
143		severity := "Info"
144		switch diagnostic.Severity {
145		case protocol.SeverityError:
146			severity = "Error"
147		case protocol.SeverityWarning:
148			severity = "Warn"
149		case protocol.SeverityHint:
150			severity = "Hint"
151		}
152
153		location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
154
155		sourceInfo := ""
156		if diagnostic.Source != "" {
157			sourceInfo = diagnostic.Source
158		} else if source != "" {
159			sourceInfo = source
160		}
161
162		codeInfo := ""
163		if diagnostic.Code != nil {
164			codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
165		}
166
167		tagsInfo := ""
168		if len(diagnostic.Tags) > 0 {
169			tags := []string{}
170			for _, tag := range diagnostic.Tags {
171				switch tag {
172				case protocol.Unnecessary:
173					tags = append(tags, "unnecessary")
174				case protocol.Deprecated:
175					tags = append(tags, "deprecated")
176				}
177			}
178			if len(tags) > 0 {
179				tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
180			}
181		}
182
183		return fmt.Sprintf("%s: %s [%s]%s%s %s",
184			severity,
185			location,
186			sourceInfo,
187			codeInfo,
188			tagsInfo,
189			diagnostic.Message)
190	}
191
192	for lspName, client := range lsps {
193		diagnostics := client.GetDiagnostics()
194		if len(diagnostics) > 0 {
195			for location, diags := range diagnostics {
196				path, err := location.Path()
197				if err != nil {
198					slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
199					continue
200				}
201				isCurrentFile := path == filePath
202
203				for _, diag := range diags {
204					formattedDiag := formatDiagnostic(path, diag, lspName)
205
206					if isCurrentFile {
207						fileDiagnostics = append(fileDiagnostics, formattedDiag)
208					} else {
209						projectDiagnostics = append(projectDiagnostics, formattedDiag)
210					}
211				}
212			}
213		}
214	}
215
216	sort.Slice(fileDiagnostics, func(i, j int) bool {
217		iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
218		jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
219		if iIsError != jIsError {
220			return iIsError // Errors come first
221		}
222		return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
223	})
224
225	sort.Slice(projectDiagnostics, func(i, j int) bool {
226		iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
227		jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
228		if iIsError != jIsError {
229			return iIsError
230		}
231		return projectDiagnostics[i] < projectDiagnostics[j]
232	})
233
234	var output strings.Builder
235
236	if len(fileDiagnostics) > 0 {
237		output.WriteString("\n<file_diagnostics>\n")
238		if len(fileDiagnostics) > 10 {
239			output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
240			fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
241		} else {
242			output.WriteString(strings.Join(fileDiagnostics, "\n"))
243		}
244		output.WriteString("\n</file_diagnostics>\n")
245	}
246
247	if len(projectDiagnostics) > 0 {
248		output.WriteString("\n<project_diagnostics>\n")
249		if len(projectDiagnostics) > 10 {
250			output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
251			fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
252		} else {
253			output.WriteString(strings.Join(projectDiagnostics, "\n"))
254		}
255		output.WriteString("\n</project_diagnostics>\n")
256	}
257
258	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
259		fileErrors := countSeverity(fileDiagnostics, "Error")
260		fileWarnings := countSeverity(fileDiagnostics, "Warn")
261		projectErrors := countSeverity(projectDiagnostics, "Error")
262		projectWarnings := countSeverity(projectDiagnostics, "Warn")
263
264		output.WriteString("\n<diagnostic_summary>\n")
265		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
266		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
267		output.WriteString("</diagnostic_summary>\n")
268	}
269
270	return output.String()
271}
272
// countSeverity returns how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn"). A nil or empty slice yields
// zero, and an empty label matches every entry.
func countSeverity(diagnostics []string, severity string) int {
	matches := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		matches++
	}
	return matches
}