1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/lsp"
13 "github.com/charmbracelet/crush/internal/lsp/protocol"
14)
15
// DiagnosticsParams holds the decoded JSON input of a diagnostics tool call.
type DiagnosticsParams struct {
	// FilePath is the file to open in the LSP clients and report separately.
	// When empty, no file is opened and every diagnostic is reported as a
	// project diagnostic.
	FilePath string `json:"file_path"`
}

19type diagnosticsTool struct {
20 lspClients map[string]*lsp.Client
21}
22
// DiagnosticsToolName is the name under which the diagnostics tool is exposed.
const DiagnosticsToolName = "diagnostics"

// diagnosticsDescription is the usage text returned as the tool's Description
// from Info.
const diagnosticsDescription = `Get diagnostics for a file and/or project.
WHEN TO USE THIS TOOL:
- Use when you need to check for errors or warnings in your code
- Helpful for debugging and ensuring code quality
- Good for getting a quick overview of issues in a file or project
HOW TO USE:
- Provide a path to a file to get diagnostics for that file
- Leave the path empty to get diagnostics for the entire project
- Results are displayed in a structured format with severity levels
FEATURES:
- Displays errors, warnings, and hints
- Groups diagnostics by severity
- Provides detailed information about each diagnostic
LIMITATIONS:
- Results are limited to the diagnostics provided by the LSP clients
- May not cover all possible issues in the code
- Does not provide suggestions for fixing issues
TIPS:
- Use in conjunction with other tools for a comprehensive code review
- Combine with the LSP client for real-time diagnostics
`
47
48func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
49 return &diagnosticsTool{
50 lspClients,
51 }
52}
53
54func (b *diagnosticsTool) Info() ToolInfo {
55 return ToolInfo{
56 Name: DiagnosticsToolName,
57 Description: diagnosticsDescription,
58 Parameters: map[string]any{
59 "file_path": map[string]any{
60 "type": "string",
61 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
62 },
63 },
64 Required: []string{},
65 }
66}
67
68func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
69 var params DiagnosticsParams
70 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
71 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
72 }
73
74 lsps := b.lspClients
75
76 if len(lsps) == 0 {
77 return NewTextErrorResponse("no LSP clients available"), nil
78 }
79
80 if params.FilePath != "" {
81 notifyLspOpenFile(ctx, params.FilePath, lsps)
82 waitForLspDiagnostics(ctx, params.FilePath, lsps)
83 }
84
85 output := getDiagnostics(params.FilePath, lsps)
86
87 return NewTextResponse(output), nil
88}
89
90func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
91 for _, client := range lsps {
92 err := client.OpenFile(ctx, filePath)
93 if err != nil {
94 continue
95 }
96 }
97}
98
99func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
100 if len(lsps) == 0 {
101 return
102 }
103
104 diagChan := make(chan struct{}, 1)
105
106 for _, client := range lsps {
107 originalDiags := client.GetDiagnostics()
108
109 handler := func(params json.RawMessage) {
110 lsp.HandleDiagnostics(client, params)
111 var diagParams protocol.PublishDiagnosticsParams
112 if err := json.Unmarshal(params, &diagParams); err != nil {
113 return
114 }
115
116 path, err := diagParams.URI.Path()
117 if err != nil {
118 slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
119 return
120 }
121
122 if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
123 select {
124 case diagChan <- struct{}{}:
125 default:
126 }
127 }
128 }
129
130 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
131
132 if client.IsFileOpen(filePath) {
133 err := client.NotifyChange(ctx, filePath)
134 if err != nil {
135 continue
136 }
137 } else {
138 err := client.OpenFile(ctx, filePath)
139 if err != nil {
140 continue
141 }
142 }
143 }
144
145 select {
146 case <-diagChan:
147 case <-time.After(5 * time.Second):
148 case <-ctx.Done():
149 }
150}
151
// hasDiagnosticsChanged reports whether current differs from original in a way
// that is cheap to detect:
//   - a URI was added or removed, or
//   - a URI's number of diagnostics changed.
//
// A change in diagnostic content that keeps the same count is not detected.
//
// It is generic over the key and diagnostic types. Existing callers that pass
// map[protocol.DocumentURI][]protocol.Diagnostic are unaffected because the type
// parameters are inferred.
func hasDiagnosticsChanged[K comparable, D any](current, original map[K][]D) bool {
	// The loop below only walks current. Without this check, a URI that was
	// removed from original would go unnoticed. With equal sizes, every key of
	// current being in original implies the key sets are identical.
	if len(current) != len(original) {
		return true
	}
	for uri, diags := range current {
		origDiags, exists := original[uri]
		if !exists || len(diags) != len(origDiags) {
			return true
		}
	}
	return false
}
161
162func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
163 fileDiagnostics := []string{}
164 projectDiagnostics := []string{}
165
166 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
167 severity := "Info"
168 switch diagnostic.Severity {
169 case protocol.SeverityError:
170 severity = "Error"
171 case protocol.SeverityWarning:
172 severity = "Warn"
173 case protocol.SeverityHint:
174 severity = "Hint"
175 }
176
177 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
178
179 sourceInfo := ""
180 if diagnostic.Source != "" {
181 sourceInfo = diagnostic.Source
182 } else if source != "" {
183 sourceInfo = source
184 }
185
186 codeInfo := ""
187 if diagnostic.Code != nil {
188 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
189 }
190
191 tagsInfo := ""
192 if len(diagnostic.Tags) > 0 {
193 tags := []string{}
194 for _, tag := range diagnostic.Tags {
195 switch tag {
196 case protocol.Unnecessary:
197 tags = append(tags, "unnecessary")
198 case protocol.Deprecated:
199 tags = append(tags, "deprecated")
200 }
201 }
202 if len(tags) > 0 {
203 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
204 }
205 }
206
207 return fmt.Sprintf("%s: %s [%s]%s%s %s",
208 severity,
209 location,
210 sourceInfo,
211 codeInfo,
212 tagsInfo,
213 diagnostic.Message)
214 }
215
216 for lspName, client := range lsps {
217 diagnostics := client.GetDiagnostics()
218 if len(diagnostics) > 0 {
219 for location, diags := range diagnostics {
220 path, err := location.Path()
221 if err != nil {
222 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
223 continue
224 }
225 isCurrentFile := path == filePath
226
227 for _, diag := range diags {
228 formattedDiag := formatDiagnostic(path, diag, lspName)
229
230 if isCurrentFile {
231 fileDiagnostics = append(fileDiagnostics, formattedDiag)
232 } else {
233 projectDiagnostics = append(projectDiagnostics, formattedDiag)
234 }
235 }
236 }
237 }
238 }
239
240 sort.Slice(fileDiagnostics, func(i, j int) bool {
241 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
242 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
243 if iIsError != jIsError {
244 return iIsError // Errors come first
245 }
246 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
247 })
248
249 sort.Slice(projectDiagnostics, func(i, j int) bool {
250 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
251 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
252 if iIsError != jIsError {
253 return iIsError
254 }
255 return projectDiagnostics[i] < projectDiagnostics[j]
256 })
257
258 var output strings.Builder
259
260 if len(fileDiagnostics) > 0 {
261 output.WriteString("\n<file_diagnostics>\n")
262 if len(fileDiagnostics) > 10 {
263 output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
264 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
265 } else {
266 output.WriteString(strings.Join(fileDiagnostics, "\n"))
267 }
268 output.WriteString("\n</file_diagnostics>\n")
269 }
270
271 if len(projectDiagnostics) > 0 {
272 output.WriteString("\n<project_diagnostics>\n")
273 if len(projectDiagnostics) > 10 {
274 output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
275 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
276 } else {
277 output.WriteString(strings.Join(projectDiagnostics, "\n"))
278 }
279 output.WriteString("\n</project_diagnostics>\n")
280 }
281
282 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
283 fileErrors := countSeverity(fileDiagnostics, "Error")
284 fileWarnings := countSeverity(fileDiagnostics, "Warn")
285 projectErrors := countSeverity(projectDiagnostics, "Error")
286 projectWarnings := countSeverity(projectDiagnostics, "Warn")
287
288 output.WriteString("\n<diagnostic_summary>\n")
289 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
290 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
291 output.WriteString("</diagnostic_summary>\n")
292 }
293
294 return output.String()
295}
296
// countSeverity returns how many formatted diagnostic lines begin with the
// given severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}