1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "log/slog"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/csync"
13 "github.com/charmbracelet/crush/internal/lsp"
14 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
15)
16
// DiagnosticsParams is the JSON input accepted by the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath selects a file whose diagnostics are reported in their own
	// section. When empty, no file is singled out and every diagnostic is
	// reported as a project diagnostic.
	FilePath string `json:"file_path"`
}
20
// diagnosticsTool reports the diagnostics held by a set of LSP clients.
type diagnosticsTool struct {
	// lspClients holds the clients keyed by LSP name; the name is used as the
	// fallback source label for diagnostics that carry no source of their own.
	lspClients *csync.Map[string, *lsp.Client]
}
24
const (
	// DiagnosticsToolName is the name this tool reports from Name and Info.
	DiagnosticsToolName = "diagnostics"
	// diagnosticsDescription is the usage text returned as the tool's
	// Description by Info.
	diagnosticsDescription = `Get diagnostics for a file and/or project.
WHEN TO USE THIS TOOL:
- Use when you need to check for errors or warnings in your code
- Helpful for debugging and ensuring code quality
- Good for getting a quick overview of issues in a file or project
HOW TO USE:
- Provide a path to a file to get diagnostics for that file
- Leave the path empty to get diagnostics for the entire project
- Results are displayed in a structured format with severity levels
FEATURES:
- Displays errors, warnings, and hints
- Groups diagnostics by severity
- Provides detailed information about each diagnostic
LIMITATIONS:
- Results are limited to the diagnostics provided by the LSP clients
- May not cover all possible issues in the code
- Does not provide suggestions for fixing issues
TIPS:
- Use in conjunction with other tools for a comprehensive code review
- Combine with the LSP client for real-time diagnostics
`
)
49
50func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
51 return &diagnosticsTool{
52 lspClients,
53 }
54}
55
56func (b *diagnosticsTool) Name() string {
57 return DiagnosticsToolName
58}
59
60func (b *diagnosticsTool) Info() ToolInfo {
61 return ToolInfo{
62 Name: DiagnosticsToolName,
63 Description: diagnosticsDescription,
64 Parameters: map[string]any{
65 "file_path": map[string]any{
66 "type": "string",
67 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
68 },
69 },
70 Required: []string{},
71 }
72}
73
74func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
75 var params DiagnosticsParams
76 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
77 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
78 }
79
80 if b.lspClients.Len() == 0 {
81 return NewTextErrorResponse("no LSP clients available"), nil
82 }
83 notifyLSPs(ctx, b.lspClients, params.FilePath)
84 output := getDiagnostics(params.FilePath, b.lspClients)
85 return NewTextResponse(output), nil
86}
87
88func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
89 if filepath == "" {
90 return
91 }
92 for client := range lsps.Seq() {
93 if !client.HandlesFile(filepath) {
94 continue
95 }
96 _ = client.OpenFileOnDemand(ctx, filepath)
97 _ = client.NotifyChange(ctx, filepath)
98 client.WaitForDiagnostics(ctx, 5*time.Second)
99 }
100}
101
102func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
103 fileDiagnostics := []string{}
104 projectDiagnostics := []string{}
105
106 for lspName, client := range lsps.Seq2() {
107 for location, diags := range client.GetDiagnostics() {
108 path, err := location.Path()
109 if err != nil {
110 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
111 continue
112 }
113 isCurrentFile := path == filePath
114 for _, diag := range diags {
115 formattedDiag := formatDiagnostic(path, diag, lspName)
116 if isCurrentFile {
117 fileDiagnostics = append(fileDiagnostics, formattedDiag)
118 } else {
119 projectDiagnostics = append(projectDiagnostics, formattedDiag)
120 }
121 }
122 }
123 }
124
125 sortDiagnostics(fileDiagnostics)
126 sortDiagnostics(projectDiagnostics)
127
128 var output strings.Builder
129 writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
130 writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
131
132 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
133 fileErrors := countSeverity(fileDiagnostics, "Error")
134 fileWarnings := countSeverity(fileDiagnostics, "Warn")
135 projectErrors := countSeverity(projectDiagnostics, "Error")
136 projectWarnings := countSeverity(projectDiagnostics, "Warn")
137 output.WriteString("\n<diagnostic_summary>\n")
138 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
139 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
140 output.WriteString("</diagnostic_summary>\n")
141 }
142
143 out := output.String()
144 slog.Info("Diagnostics", "output", fmt.Sprintf("%q", out))
145 return out
146}
147
// writeDiagnostics appends in to output wrapped in <tag>...</tag>, one entry
// per line. At most ten entries are written; any remainder is summarized by a
// trailing "... and N more diagnostics" line. Nothing is written for an empty
// slice.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	if len(in) == 0 {
		return
	}

	const maxShown = 10
	shown := in
	if len(shown) > maxShown {
		shown = shown[:maxShown]
	}

	output.WriteString("\n<" + tag + ">\n")
	output.WriteString(strings.Join(shown, "\n"))
	if hidden := len(in) - len(shown); hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	output.WriteString("\n</" + tag + ">\n")
}
161
// sortDiagnostics sorts formatted diagnostics in place and returns the same
// slice. Entries are ordered by severity (Error, Warn, Info, Hint, then
// anything unrecognized) and alphabetically within a severity.
//
// Ordering by severity rank, rather than by the label's spelling, matters
// because writeDiagnostics truncates its output. An alphabetical order would
// place "Hint" and "Info" ahead of "Warn" and hide warnings first.
func sortDiagnostics(in []string) []string {
	// Severity prefixes as produced by formatDiagnostic, in display order.
	order := [...]string{"Error:", "Warn:", "Info:", "Hint:"}
	rank := func(d string) int {
		for i, prefix := range order {
			if strings.HasPrefix(d, prefix) {
				return i
			}
		}
		return len(order)
	}

	sort.Slice(in, func(i, j int) bool {
		ri, rj := rank(in[i]), rank(in[j])
		if ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
173
174func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
175 severity := "Info"
176 switch diagnostic.Severity {
177 case protocol.SeverityError:
178 severity = "Error"
179 case protocol.SeverityWarning:
180 severity = "Warn"
181 case protocol.SeverityHint:
182 severity = "Hint"
183 }
184
185 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
186
187 sourceInfo := ""
188 if diagnostic.Source != "" {
189 sourceInfo = diagnostic.Source
190 } else if source != "" {
191 sourceInfo = source
192 }
193
194 codeInfo := ""
195 if diagnostic.Code != nil {
196 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
197 }
198
199 tagsInfo := ""
200 if len(diagnostic.Tags) > 0 {
201 tags := []string{}
202 for _, tag := range diagnostic.Tags {
203 switch tag {
204 case protocol.Unnecessary:
205 tags = append(tags, "unnecessary")
206 case protocol.Deprecated:
207 tags = append(tags, "deprecated")
208 }
209 }
210 if len(tags) > 0 {
211 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
212 }
213 }
214
215 return fmt.Sprintf("%s: %s [%s]%s%s %s",
216 severity,
217 location,
218 sourceInfo,
219 codeInfo,
220 tagsInfo,
221 diagnostic.Message)
222}
223
// countSeverity returns how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if strings.HasPrefix(diagnostics[i], severity) {
			n++
		}
	}
	return n
}