1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "maps"
9 "sort"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/lsp"
14 "github.com/charmbracelet/crush/internal/lsp/protocol"
15)
16
// DiagnosticsParams is the JSON input accepted by the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath is the file to report diagnostics for. When it is empty, Run
	// skips opening the file and waiting for fresh diagnostics.
	FilePath string `json:"file_path"`
}
// diagnosticsTool reports diagnostics collected from a set of LSP clients.
type diagnosticsTool struct {
	// lspClients maps an LSP name to its client; the name is used as the
	// fallback source label when a diagnostic carries no source of its own.
	lspClients map[string]*lsp.Client
}
23
const (
	// DiagnosticsToolName is the identifier returned by the tool's Name
	// method and placed in ToolInfo.Name.
	DiagnosticsToolName = "diagnostics"
	// diagnosticsDescription is the help text placed in ToolInfo.Description.
	diagnosticsDescription = `Get diagnostics for a file and/or project.
WHEN TO USE THIS TOOL:
- Use when you need to check for errors or warnings in your code
- Helpful for debugging and ensuring code quality
- Good for getting a quick overview of issues in a file or project
HOW TO USE:
- Provide a path to a file to get diagnostics for that file
- Leave the path empty to get diagnostics for the entire project
- Results are displayed in a structured format with severity levels
FEATURES:
- Displays errors, warnings, and hints
- Groups diagnostics by severity
- Provides detailed information about each diagnostic
LIMITATIONS:
- Results are limited to the diagnostics provided by the LSP clients
- May not cover all possible issues in the code
- Does not provide suggestions for fixing issues
TIPS:
- Use in conjunction with other tools for a comprehensive code review
- Combine with the LSP client for real-time diagnostics
`
)
48
49func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
50 return &diagnosticsTool{
51 lspClients,
52 }
53}
54
// Name returns the tool's identifier, DiagnosticsToolName.
func (b *diagnosticsTool) Name() string {
	return DiagnosticsToolName
}
58
59func (b *diagnosticsTool) Info() ToolInfo {
60 return ToolInfo{
61 Name: DiagnosticsToolName,
62 Description: diagnosticsDescription,
63 Parameters: map[string]any{
64 "file_path": map[string]any{
65 "type": "string",
66 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
67 },
68 },
69 Required: []string{},
70 }
71}
72
73func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
74 var params DiagnosticsParams
75 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
76 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
77 }
78
79 lsps := b.lspClients
80
81 if len(lsps) == 0 {
82 return NewTextErrorResponse("no LSP clients available"), nil
83 }
84
85 if params.FilePath != "" {
86 notifyLspOpenFile(ctx, params.FilePath, lsps)
87 waitForLspDiagnostics(ctx, params.FilePath, lsps)
88 }
89
90 output := getDiagnostics(params.FilePath, lsps)
91
92 return NewTextResponse(output), nil
93}
94
95func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
96 for _, client := range lsps {
97 err := client.OpenFile(ctx, filePath)
98 if err != nil {
99 continue
100 }
101 }
102}
103
104func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
105 if len(lsps) == 0 {
106 return
107 }
108
109 diagChan := make(chan struct{}, 1)
110
111 for _, client := range lsps {
112 originalDiags := make(map[protocol.DocumentURI][]protocol.Diagnostic)
113 maps.Copy(originalDiags, client.GetDiagnostics())
114
115 handler := func(params json.RawMessage) {
116 client.HandleDiagnostics(params)
117 var diagParams protocol.PublishDiagnosticsParams
118 if err := json.Unmarshal(params, &diagParams); err != nil {
119 return
120 }
121
122 path, err := diagParams.URI.Path()
123 if err != nil {
124 slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
125 return
126 }
127
128 if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
129 select {
130 case diagChan <- struct{}{}:
131 default:
132 }
133 }
134 }
135
136 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
137
138 if client.IsFileOpen(filePath) {
139 err := client.NotifyChange(ctx, filePath)
140 if err != nil {
141 continue
142 }
143 } else {
144 err := client.OpenFile(ctx, filePath)
145 if err != nil {
146 continue
147 }
148 }
149 }
150
151 select {
152 case <-diagChan:
153 case <-time.After(5 * time.Second):
154 case <-ctx.Done():
155 }
156}
157
158func hasDiagnosticsChanged(current, original map[protocol.DocumentURI][]protocol.Diagnostic) bool {
159 for uri, diags := range current {
160 origDiags, exists := original[uri]
161 if !exists || len(diags) != len(origDiags) {
162 return true
163 }
164 }
165 return false
166}
167
168func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
169 fileDiagnostics := []string{}
170 projectDiagnostics := []string{}
171
172 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
173 severity := "Info"
174 switch diagnostic.Severity {
175 case protocol.SeverityError:
176 severity = "Error"
177 case protocol.SeverityWarning:
178 severity = "Warn"
179 case protocol.SeverityHint:
180 severity = "Hint"
181 }
182
183 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
184
185 sourceInfo := ""
186 if diagnostic.Source != "" {
187 sourceInfo = diagnostic.Source
188 } else if source != "" {
189 sourceInfo = source
190 }
191
192 codeInfo := ""
193 if diagnostic.Code != nil {
194 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
195 }
196
197 tagsInfo := ""
198 if len(diagnostic.Tags) > 0 {
199 tags := []string{}
200 for _, tag := range diagnostic.Tags {
201 switch tag {
202 case protocol.Unnecessary:
203 tags = append(tags, "unnecessary")
204 case protocol.Deprecated:
205 tags = append(tags, "deprecated")
206 }
207 }
208 if len(tags) > 0 {
209 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
210 }
211 }
212
213 return fmt.Sprintf("%s: %s [%s]%s%s %s",
214 severity,
215 location,
216 sourceInfo,
217 codeInfo,
218 tagsInfo,
219 diagnostic.Message)
220 }
221
222 for lspName, client := range lsps {
223 diagnostics := client.GetDiagnostics()
224 if len(diagnostics) > 0 {
225 for location, diags := range diagnostics {
226 path, err := location.Path()
227 if err != nil {
228 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
229 continue
230 }
231 isCurrentFile := path == filePath
232
233 for _, diag := range diags {
234 formattedDiag := formatDiagnostic(path, diag, lspName)
235
236 if isCurrentFile {
237 fileDiagnostics = append(fileDiagnostics, formattedDiag)
238 } else {
239 projectDiagnostics = append(projectDiagnostics, formattedDiag)
240 }
241 }
242 }
243 }
244 }
245
246 sort.Slice(fileDiagnostics, func(i, j int) bool {
247 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
248 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
249 if iIsError != jIsError {
250 return iIsError // Errors come first
251 }
252 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
253 })
254
255 sort.Slice(projectDiagnostics, func(i, j int) bool {
256 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
257 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
258 if iIsError != jIsError {
259 return iIsError
260 }
261 return projectDiagnostics[i] < projectDiagnostics[j]
262 })
263
264 var output strings.Builder
265
266 if len(fileDiagnostics) > 0 {
267 output.WriteString("\n<file_diagnostics>\n")
268 if len(fileDiagnostics) > 10 {
269 output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
270 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
271 } else {
272 output.WriteString(strings.Join(fileDiagnostics, "\n"))
273 }
274 output.WriteString("\n</file_diagnostics>\n")
275 }
276
277 if len(projectDiagnostics) > 0 {
278 output.WriteString("\n<project_diagnostics>\n")
279 if len(projectDiagnostics) > 10 {
280 output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
281 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
282 } else {
283 output.WriteString(strings.Join(projectDiagnostics, "\n"))
284 }
285 output.WriteString("\n</project_diagnostics>\n")
286 }
287
288 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
289 fileErrors := countSeverity(fileDiagnostics, "Error")
290 fileWarnings := countSeverity(fileDiagnostics, "Warn")
291 projectErrors := countSeverity(projectDiagnostics, "Error")
292 projectWarnings := countSeverity(projectDiagnostics, "Warn")
293
294 output.WriteString("\n<diagnostic_summary>\n")
295 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
296 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
297 output.WriteString("</diagnostic_summary>\n")
298 }
299
300 return output.String()
301}
302
// countSeverity returns how many entries of diagnostics start with the given
// severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	var matched int
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		matched++
	}
	return matched
}