1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/ai"
13 "github.com/charmbracelet/crush/internal/lsp"
14 "github.com/charmbracelet/crush/internal/lsp/protocol"
15)
16
// DiagnosticsParams is the input accepted by the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath, when non-empty, is opened/refreshed in every LSP client before
	// diagnostics are collected, and diagnostics for that path are reported
	// separately from the rest of the project. Leave it empty to report only
	// the diagnostics the clients already hold.
	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave empty for project diagnostics)"`
}
20
// DiagnosticsToolName is the name passed to ai.NewTypedToolFunc when the
// diagnostics tool is constructed.
const DiagnosticsToolName = "diagnostics"
24
25func NewDiagnosticsTool(lsps map[string]*lsp.Client) ai.AgentTool {
26 return ai.NewTypedToolFunc(
27 DiagnosticsToolName,
28 `Get diagnostics for a file and/or project.
29WHEN TO USE THIS TOOL:
30- Use when you need to check for errors or warnings in your code
31- Helpful for debugging and ensuring code quality
32- Good for getting a quick overview of issues in a file or project
33HOW TO USE:
34- Provide a path to a file to get diagnostics for that file
35- Leave the path empty to get diagnostics for the entire project
36- Results are displayed in a structured format with severity levels
37FEATURES:
38- Displays errors, warnings, and hints
39- Groups diagnostics by severity
40- Provides detailed information about each diagnostic
41LIMITATIONS:
42- Results are limited to the diagnostics provided by the LSP clients
43- May not cover all possible issues in the code
44- Does not provide suggestions for fixing issues
45TIPS:
46- Use in conjunction with other tools for a comprehensive code review
47- Combine with the LSP client for real-time diagnostics`,
48
49 func(ctx context.Context, params DiagnosticsParams, call ai.ToolCall) (ai.ToolResponse, error) {
50 if len(lsps) == 0 {
51 return ai.NewTextErrorResponse("no LSP clients available"), nil
52 }
53
54 if params.FilePath != "" {
55 notifyLspOpenFile(ctx, params.FilePath, lsps)
56 waitForLspDiagnostics(ctx, params.FilePath, lsps)
57 }
58
59 output := getDiagnostics(params.FilePath, lsps)
60
61 return ai.NewTextResponse(output), nil
62 },
63 )
64}
65
66func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
67 for _, client := range lsps {
68 err := client.OpenFile(ctx, filePath)
69 if err != nil {
70 continue
71 }
72 }
73}
74
75func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
76 if len(lsps) == 0 {
77 return
78 }
79
80 diagChan := make(chan struct{}, 1)
81
82 for _, client := range lsps {
83 originalDiags := client.GetDiagnostics()
84
85 handler := func(params json.RawMessage) {
86 lsp.HandleDiagnostics(client, params)
87 var diagParams protocol.PublishDiagnosticsParams
88 if err := json.Unmarshal(params, &diagParams); err != nil {
89 return
90 }
91
92 path, err := diagParams.URI.Path()
93 if err != nil {
94 slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
95 return
96 }
97
98 if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
99 select {
100 case diagChan <- struct{}{}:
101 default:
102 }
103 }
104 }
105
106 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
107
108 if client.IsFileOpen(filePath) {
109 err := client.NotifyChange(ctx, filePath)
110 if err != nil {
111 continue
112 }
113 } else {
114 err := client.OpenFile(ctx, filePath)
115 if err != nil {
116 continue
117 }
118 }
119 }
120
121 select {
122 case <-diagChan:
123 case <-time.After(5 * time.Second):
124 case <-ctx.Done():
125 }
126}
127
128func hasDiagnosticsChanged(current, original map[protocol.DocumentURI][]protocol.Diagnostic) bool {
129 for uri, diags := range current {
130 origDiags, exists := original[uri]
131 if !exists || len(diags) != len(origDiags) {
132 return true
133 }
134 }
135 return false
136}
137
138func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
139 fileDiagnostics := []string{}
140 projectDiagnostics := []string{}
141
142 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
143 severity := "Info"
144 switch diagnostic.Severity {
145 case protocol.SeverityError:
146 severity = "Error"
147 case protocol.SeverityWarning:
148 severity = "Warn"
149 case protocol.SeverityHint:
150 severity = "Hint"
151 }
152
153 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
154
155 sourceInfo := ""
156 if diagnostic.Source != "" {
157 sourceInfo = diagnostic.Source
158 } else if source != "" {
159 sourceInfo = source
160 }
161
162 codeInfo := ""
163 if diagnostic.Code != nil {
164 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
165 }
166
167 tagsInfo := ""
168 if len(diagnostic.Tags) > 0 {
169 tags := []string{}
170 for _, tag := range diagnostic.Tags {
171 switch tag {
172 case protocol.Unnecessary:
173 tags = append(tags, "unnecessary")
174 case protocol.Deprecated:
175 tags = append(tags, "deprecated")
176 }
177 }
178 if len(tags) > 0 {
179 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
180 }
181 }
182
183 return fmt.Sprintf("%s: %s [%s]%s%s %s",
184 severity,
185 location,
186 sourceInfo,
187 codeInfo,
188 tagsInfo,
189 diagnostic.Message)
190 }
191
192 for lspName, client := range lsps {
193 diagnostics := client.GetDiagnostics()
194 if len(diagnostics) > 0 {
195 for location, diags := range diagnostics {
196 path, err := location.Path()
197 if err != nil {
198 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
199 continue
200 }
201 isCurrentFile := path == filePath
202
203 for _, diag := range diags {
204 formattedDiag := formatDiagnostic(path, diag, lspName)
205
206 if isCurrentFile {
207 fileDiagnostics = append(fileDiagnostics, formattedDiag)
208 } else {
209 projectDiagnostics = append(projectDiagnostics, formattedDiag)
210 }
211 }
212 }
213 }
214 }
215
216 sort.Slice(fileDiagnostics, func(i, j int) bool {
217 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
218 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
219 if iIsError != jIsError {
220 return iIsError // Errors come first
221 }
222 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
223 })
224
225 sort.Slice(projectDiagnostics, func(i, j int) bool {
226 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
227 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
228 if iIsError != jIsError {
229 return iIsError
230 }
231 return projectDiagnostics[i] < projectDiagnostics[j]
232 })
233
234 var output strings.Builder
235
236 if len(fileDiagnostics) > 0 {
237 output.WriteString("\n<file_diagnostics>\n")
238 if len(fileDiagnostics) > 10 {
239 output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
240 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
241 } else {
242 output.WriteString(strings.Join(fileDiagnostics, "\n"))
243 }
244 output.WriteString("\n</file_diagnostics>\n")
245 }
246
247 if len(projectDiagnostics) > 0 {
248 output.WriteString("\n<project_diagnostics>\n")
249 if len(projectDiagnostics) > 10 {
250 output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
251 fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
252 } else {
253 output.WriteString(strings.Join(projectDiagnostics, "\n"))
254 }
255 output.WriteString("\n</project_diagnostics>\n")
256 }
257
258 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
259 fileErrors := countSeverity(fileDiagnostics, "Error")
260 fileWarnings := countSeverity(fileDiagnostics, "Warn")
261 projectErrors := countSeverity(projectDiagnostics, "Error")
262 projectWarnings := countSeverity(projectDiagnostics, "Warn")
263
264 output.WriteString("\n<diagnostic_summary>\n")
265 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
266 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
267 output.WriteString("</diagnostic_summary>\n")
268 }
269
270 return output.String()
271}
272
// countSeverity returns how many formatted diagnostic lines begin with the
// given severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}