1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/lsp"
13 "github.com/charmbracelet/crush/internal/lsp/protocol"
14)
15
// DiagnosticsParams is the JSON input accepted by the diagnostics tool.
//
// FilePath is optional. When it is empty, no file is opened in the LSP
// clients and every collected diagnostic is reported as a project diagnostic.
type DiagnosticsParams struct {
	FilePath string `json:"file_path"`
}

// diagnosticsTool reports diagnostics collected from a set of LSP clients,
// keyed by client name (the key is used as the fallback diagnostic source).
type diagnosticsTool struct {
	lspClients map[string]*lsp.Client
}
22
const (
	// DiagnosticsToolName is the tool identifier returned by Name and used
	// as ToolInfo.Name.
	DiagnosticsToolName = "diagnostics"
	// diagnosticsDescription is the description returned by Info.
	// NOTE(review): the text says diagnostics are grouped by severity, but the
	// output is split into file and project sections, each ordered with errors
	// first; consider aligning the wording with the actual output.
	diagnosticsDescription = `Get diagnostics for a file and/or project.
WHEN TO USE THIS TOOL:
- Use when you need to check for errors or warnings in your code
- Helpful for debugging and ensuring code quality
- Good for getting a quick overview of issues in a file or project
HOW TO USE:
- Provide a path to a file to get diagnostics for that file
- Leave the path empty to get diagnostics for the entire project
- Results are displayed in a structured format with severity levels
FEATURES:
- Displays errors, warnings, and hints
- Groups diagnostics by severity
- Provides detailed information about each diagnostic
LIMITATIONS:
- Results are limited to the diagnostics provided by the LSP clients
- May not cover all possible issues in the code
- Does not provide suggestions for fixing issues
TIPS:
- Use in conjunction with other tools for a comprehensive code review
- Combine with the LSP client for real-time diagnostics
`
)
47
48func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
49 return &diagnosticsTool{
50 lspClients,
51 }
52}
53
54func (b *diagnosticsTool) Name() string {
55 return DiagnosticsToolName
56}
57
58func (b *diagnosticsTool) Info() ToolInfo {
59 return ToolInfo{
60 Name: DiagnosticsToolName,
61 Description: diagnosticsDescription,
62 Parameters: map[string]any{
63 "file_path": map[string]any{
64 "type": "string",
65 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
66 },
67 },
68 Required: []string{},
69 }
70}
71
72func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
73 var params DiagnosticsParams
74 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
75 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
76 }
77
78 lsps := b.lspClients
79
80 if len(lsps) == 0 {
81 return NewTextErrorResponse("no LSP clients available"), nil
82 }
83
84 if params.FilePath != "" {
85 notifyLspOpenFile(ctx, params.FilePath, lsps)
86 waitForLspDiagnostics(ctx, params.FilePath, lsps)
87 }
88
89 output := getDiagnostics(params.FilePath, lsps)
90
91 return NewTextResponse(output), nil
92}
93
94func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
95 for _, client := range lsps {
96 err := client.OpenFile(ctx, filePath)
97 if err != nil {
98 continue
99 }
100 }
101}
102
103func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
104 if len(lsps) == 0 {
105 return
106 }
107
108 diagChan := make(chan struct{}, 1)
109
110 for _, client := range lsps {
111 originalDiags := client.GetDiagnostics()
112
113 handler := func(params json.RawMessage) {
114 lsp.HandleDiagnostics(client, params)
115 var diagParams protocol.PublishDiagnosticsParams
116 if err := json.Unmarshal(params, &diagParams); err != nil {
117 return
118 }
119
120 path, err := diagParams.URI.Path()
121 if err != nil {
122 slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
123 return
124 }
125
126 if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
127 select {
128 case diagChan <- struct{}{}:
129 default:
130 }
131 }
132 }
133
134 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
135
136 if client.IsFileOpen(filePath) {
137 err := client.NotifyChange(ctx, filePath)
138 if err != nil {
139 continue
140 }
141 } else {
142 err := client.OpenFile(ctx, filePath)
143 if err != nil {
144 continue
145 }
146 }
147 }
148
149 select {
150 case <-diagChan:
151 case <-time.After(5 * time.Second):
152 case <-ctx.Done():
153 }
154}
155
156func hasDiagnosticsChanged(current, original map[protocol.DocumentURI][]protocol.Diagnostic) bool {
157 for uri, diags := range current {
158 origDiags, exists := original[uri]
159 if !exists || len(diags) != len(origDiags) {
160 return true
161 }
162 }
163 return false
164}
165
166func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
167 fileDiagnostics := []string{}
168 projectDiagnostics := []string{}
169
170 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
171 severity := "Info"
172 switch diagnostic.Severity {
173 case protocol.SeverityError:
174 severity = "Error"
175 case protocol.SeverityWarning:
176 severity = "Warn"
177 case protocol.SeverityHint:
178 severity = "Hint"
179 }
180
181 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
182
183 sourceInfo := ""
184 if diagnostic.Source != "" {
185 sourceInfo = diagnostic.Source
186 } else if source != "" {
187 sourceInfo = source
188 }
189
190 codeInfo := ""
191 if diagnostic.Code != nil {
192 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
193 }
194
195 tagsInfo := ""
196 if len(diagnostic.Tags) > 0 {
197 tags := []string{}
198 for _, tag := range diagnostic.Tags {
199 switch tag {
200 case protocol.Unnecessary:
201 tags = append(tags, "unnecessary")
202 case protocol.Deprecated:
203 tags = append(tags, "deprecated")
204 }
205 }
206 if len(tags) > 0 {
207 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
208 }
209 }
210
211 return fmt.Sprintf("%s: %s [%s]%s%s %s",
212 severity,
213 location,
214 sourceInfo,
215 codeInfo,
216 tagsInfo,
217 diagnostic.Message)
218 }
219
220 for lspName, client := range lsps {
221 diagnostics := client.GetDiagnostics()
222 if len(diagnostics) > 0 {
223 for location, diags := range diagnostics {
224 path, err := location.Path()
225 if err != nil {
226 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
227 continue
228 }
229 isCurrentFile := path == filePath
230
231 for _, diag := range diags {
232 formattedDiag := formatDiagnostic(path, diag, lspName)
233
234 if isCurrentFile {
235 fileDiagnostics = append(fileDiagnostics, formattedDiag)
236 } else {
237 projectDiagnostics = append(projectDiagnostics, formattedDiag)
238 }
239 }
240 }
241 }
242 }
243
244 sort.Slice(fileDiagnostics, func(i, j int) bool {
245 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
246 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
247 if iIsError != jIsError {
248 return iIsError // Errors come first
249 }
250 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
251 })
252
253 sort.Slice(projectDiagnostics, func(i, j int) bool {
254 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
255 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
256 if iIsError != jIsError {
257 return iIsError
258 }
259 return projectDiagnostics[i] < projectDiagnostics[j]
260 })
261
262 var output strings.Builder
263
264 if len(fileDiagnostics) > 0 {
265 output.WriteString("\n<file_diagnostics>\n")
266 if len(fileDiagnostics) > 10 {
267 output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
268 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
269 } else {
270 output.WriteString(strings.Join(fileDiagnostics, "\n"))
271 }
272 output.WriteString("\n</file_diagnostics>\n")
273 }
274
275 if len(projectDiagnostics) > 0 {
276 output.WriteString("\n<project_diagnostics>\n")
277 if len(projectDiagnostics) > 10 {
278 output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
279 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
280 } else {
281 output.WriteString(strings.Join(projectDiagnostics, "\n"))
282 }
283 output.WriteString("\n</project_diagnostics>\n")
284 }
285
286 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
287 fileErrors := countSeverity(fileDiagnostics, "Error")
288 fileWarnings := countSeverity(fileDiagnostics, "Warn")
289 projectErrors := countSeverity(projectDiagnostics, "Error")
290 projectWarnings := countSeverity(projectDiagnostics, "Warn")
291
292 output.WriteString("\n<diagnostic_summary>\n")
293 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
294 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
295 output.WriteString("</diagnostic_summary>\n")
296 }
297
298 return output.String()
299}
300
// countSeverity reports how many formatted diagnostic lines begin with the
// given severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}