1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "sort"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/charmbracelet/crush/internal/lsp"
15 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
16)
17
// DiagnosticsParams is the JSON input accepted by the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath is the file whose diagnostics are reported separately under
	// <file_diagnostics>. When empty, no file is opened or refreshed before
	// collecting, and all diagnostics are reported as project diagnostics.
	FilePath string `json:"file_path"`
}
21
// diagnosticsTool reports diagnostics gathered by a set of LSP clients.
type diagnosticsTool struct {
	// lspClients maps an LSP name to its client. The name doubles as the
	// fallback source label when a diagnostic carries no Source of its own.
	lspClients *csync.Map[string, *lsp.Client]
}
25
// DiagnosticsToolName is the name returned by the diagnostics tool's Name
// method and reported in its Info.
const DiagnosticsToolName = "lsp_diagnostics"
27
// diagnosticsDescription is the tool description used by Info, embedded at
// build time from diagnostics.md.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
30
31func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
32 return &diagnosticsTool{
33 lspClients,
34 }
35}
36
37func (b *diagnosticsTool) Name() string {
38 return DiagnosticsToolName
39}
40
41func (b *diagnosticsTool) Info() ToolInfo {
42 return ToolInfo{
43 Name: DiagnosticsToolName,
44 Description: string(diagnosticsDescription),
45 Parameters: map[string]any{
46 "file_path": map[string]any{
47 "type": "string",
48 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
49 },
50 },
51 Required: []string{},
52 }
53}
54
55func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
56 var params DiagnosticsParams
57 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
58 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
59 }
60
61 if b.lspClients.Len() == 0 {
62 return NewTextErrorResponse("no LSP clients available"), nil
63 }
64 notifyLSPs(ctx, b.lspClients, params.FilePath)
65 output := getDiagnostics(params.FilePath, b.lspClients)
66 return NewTextResponse(output), nil
67}
68
69func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
70 if filepath == "" {
71 return
72 }
73 for client := range lsps.Seq() {
74 if !client.HandlesFile(filepath) {
75 continue
76 }
77 _ = client.OpenFileOnDemand(ctx, filepath)
78 _ = client.NotifyChange(ctx, filepath)
79 client.WaitForDiagnostics(ctx, 5*time.Second)
80 }
81}
82
83func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
84 fileDiagnostics := []string{}
85 projectDiagnostics := []string{}
86
87 for lspName, client := range lsps.Seq2() {
88 for location, diags := range client.GetDiagnostics() {
89 path, err := location.Path()
90 if err != nil {
91 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
92 continue
93 }
94 isCurrentFile := path == filePath
95 for _, diag := range diags {
96 formattedDiag := formatDiagnostic(path, diag, lspName)
97 if isCurrentFile {
98 fileDiagnostics = append(fileDiagnostics, formattedDiag)
99 } else {
100 projectDiagnostics = append(projectDiagnostics, formattedDiag)
101 }
102 }
103 }
104 }
105
106 sortDiagnostics(fileDiagnostics)
107 sortDiagnostics(projectDiagnostics)
108
109 var output strings.Builder
110 writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
111 writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
112
113 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
114 fileErrors := countSeverity(fileDiagnostics, "Error")
115 fileWarnings := countSeverity(fileDiagnostics, "Warn")
116 projectErrors := countSeverity(projectDiagnostics, "Error")
117 projectWarnings := countSeverity(projectDiagnostics, "Warn")
118 output.WriteString("\n<diagnostic_summary>\n")
119 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
120 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
121 output.WriteString("</diagnostic_summary>\n")
122 }
123
124 out := output.String()
125 slog.Info("Diagnostics", "output", out)
126 return out
127}
128
// writeDiagnostics appends the entries of in to output, wrapped in
// "<tag>" / "</tag>" lines with one entry per line.
//
// At most 10 entries are written. Any remainder is summarized as
// "... and N more diagnostics". Nothing at all is written for an empty slice.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	const limit = 10
	if len(in) == 0 {
		return
	}

	shown, hidden := in, 0
	if len(in) > limit {
		shown, hidden = in[:limit], len(in)-limit
	}

	output.WriteString("\n<" + tag + ">\n")
	output.WriteString(strings.Join(shown, "\n"))
	if hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	output.WriteString("\n</" + tag + ">\n")
}
142
// sortDiagnostics orders formatted diagnostics in place and also returns the
// slice for convenience.
//
// The order is by severity label — Error, Warn, Info, Hint, then anything
// unrecognized — and alphabetical within the same severity.
//
// An explicit rank is needed because writeDiagnostics keeps only the first
// 10 entries. Plain alphabetical order would put "Warn" after "Hint" and
// "Info", so warnings would be the first to be truncated away.
func sortDiagnostics(in []string) []string {
	rank := func(s string) int {
		switch {
		case strings.HasPrefix(s, "Error"):
			return 0
		case strings.HasPrefix(s, "Warn"):
			return 1
		case strings.HasPrefix(s, "Info"):
			return 2
		case strings.HasPrefix(s, "Hint"):
			return 3
		default:
			return 4
		}
	}
	sort.Slice(in, func(i, j int) bool {
		ri, rj := rank(in[i]), rank(in[j])
		if ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
154
155func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
156 severity := "Info"
157 switch diagnostic.Severity {
158 case protocol.SeverityError:
159 severity = "Error"
160 case protocol.SeverityWarning:
161 severity = "Warn"
162 case protocol.SeverityHint:
163 severity = "Hint"
164 }
165
166 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
167
168 sourceInfo := ""
169 if diagnostic.Source != "" {
170 sourceInfo = diagnostic.Source
171 } else if source != "" {
172 sourceInfo = source
173 }
174
175 codeInfo := ""
176 if diagnostic.Code != nil {
177 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
178 }
179
180 tagsInfo := ""
181 if len(diagnostic.Tags) > 0 {
182 tags := []string{}
183 for _, tag := range diagnostic.Tags {
184 switch tag {
185 case protocol.Unnecessary:
186 tags = append(tags, "unnecessary")
187 case protocol.Deprecated:
188 tags = append(tags, "deprecated")
189 }
190 }
191 if len(tags) > 0 {
192 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
193 }
194 }
195
196 return fmt.Sprintf("%s: %s [%s]%s%s %s",
197 severity,
198 location,
199 sourceInfo,
200 codeInfo,
201 tagsInfo,
202 diagnostic.Message)
203}
204
// countSeverity reports how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}