1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "sort"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/charmbracelet/crush/internal/lsp"
15 "github.com/charmbracelet/crush/internal/proto"
16 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
17)
18
// DiagnosticsParams is the JSON input accepted by the diagnostics tool. It is
// an alias of proto.DiagnosticsParams, so both packages share one type.
type DiagnosticsParams = proto.DiagnosticsParams

// diagnosticsTool reports diagnostics collected from a set of LSP clients.
// The map key is used as the LSP name (see getDiagnostics).
type diagnosticsTool struct {
	lspClients *csync.Map[string, *lsp.Client]
}

// DiagnosticsToolName is the identifier returned by Name and used in Info.
const DiagnosticsToolName = "diagnostics"

// diagnosticsDescription is the tool description used by Info; its contents
// are embedded at build time from diagnostics.md.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
29
30func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
31 return &diagnosticsTool{
32 lspClients,
33 }
34}
35
36func (b *diagnosticsTool) Name() string {
37 return DiagnosticsToolName
38}
39
40func (b *diagnosticsTool) Info() ToolInfo {
41 return ToolInfo{
42 Name: DiagnosticsToolName,
43 Description: string(diagnosticsDescription),
44 Parameters: map[string]any{
45 "file_path": map[string]any{
46 "type": "string",
47 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
48 },
49 },
50 Required: []string{},
51 }
52}
53
// Run executes the diagnostics tool for one call.
//
// It decodes call.Input as DiagnosticsParams and asks the LSP clients that
// handle params.FilePath to refresh it (see notifyLSPs). It then returns the
// rendered diagnostics from getDiagnostics as a text response. Bad input and
// a missing or empty client set are reported as error responses. The Go
// error result is always nil.
func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
	var params DiagnosticsParams
	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	// Guard nil as well as empty: calling Len on a nil map pointer would panic.
	if b.lspClients == nil || b.lspClients.Len() == 0 {
		return NewTextErrorResponse("no LSP clients available"), nil
	}

	notifyLSPs(ctx, b.lspClients, params.FilePath)
	output := getDiagnostics(params.FilePath, b.lspClients)
	return NewTextResponse(output), nil
}
67
68func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
69 if filepath == "" {
70 return
71 }
72 for client := range lsps.Seq() {
73 if !client.HandlesFile(filepath) {
74 continue
75 }
76 _ = client.OpenFileOnDemand(ctx, filepath)
77 _ = client.NotifyChange(ctx, filepath)
78 client.WaitForDiagnostics(ctx, 5*time.Second)
79 }
80}
81
// getDiagnostics gathers the diagnostics currently held by every client in
// lsps and renders them as text.
//
// Diagnostics whose path equals filePath go under <file_diagnostics>; all
// others go under <project_diagnostics>. Each section is ordered by
// sortDiagnostics and capped by writeDiagnostics. A <diagnostic_summary> with
// error/warning counts is appended when anything was found. The result is ""
// when no client reports any diagnostics.
func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
	var fileDiagnostics, projectDiagnostics []string

	for lspName, client := range lsps.Seq2() {
		for location, diags := range client.GetDiagnostics() {
			path, err := location.Path()
			if err != nil {
				slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
				continue
			}
			// Paths are compared verbatim.
			// NOTE(review): a relative or non-normalized filePath will never
			// match the LSP-derived path; confirm callers pass the same form.
			bucket := &projectDiagnostics
			if path == filePath {
				bucket = &fileDiagnostics
			}
			for _, diag := range diags {
				// lspName is the fallback source label when the diagnostic has none.
				*bucket = append(*bucket, formatDiagnostic(path, diag, lspName))
			}
		}
	}

	sortDiagnostics(fileDiagnostics)
	sortDiagnostics(projectDiagnostics)

	var output strings.Builder
	writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
	writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)

	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
		// Counts cover every diagnostic, including entries truncated from the
		// sections above.
		output.WriteString("\n<diagnostic_summary>\n")
		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n",
			countSeverity(fileDiagnostics, "Error"), countSeverity(fileDiagnostics, "Warn"))
		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n",
			countSeverity(projectDiagnostics, "Error"), countSeverity(projectDiagnostics, "Warn"))
		output.WriteString("</diagnostic_summary>\n")
	}

	out := output.String()
	// Debug rather than Info: the payload can be large and this runs on every
	// call. slog handlers escape the value themselves, so no %q pre-quoting.
	slog.Debug("Diagnostics", "output", out)
	return out
}
127
128func writeDiagnostics(output *strings.Builder, tag string, in []string) {
129 if len(in) == 0 {
130 return
131 }
132 output.WriteString("\n<" + tag + ">\n")
133 if len(in) > 10 {
134 output.WriteString(strings.Join(in[:10], "\n"))
135 fmt.Fprintf(output, "\n... and %d more diagnostics", len(in)-10)
136 } else {
137 output.WriteString(strings.Join(in, "\n"))
138 }
139 output.WriteString("\n</" + tag + ">\n")
140}
141
// sortDiagnostics sorts formatted diagnostics in place by severity, from
// Error through Warn, Info, and Hint, then alphabetically within a severity.
// It returns in for convenience.
//
// A plain alphabetical order after errors would put "Hint" and "Info" ahead
// of "Warn". Because writeDiagnostics keeps only the first entries, that
// order could drop warnings in favor of hints, so an explicit rank is used.
// Entries with an unrecognized prefix sort last.
func sortDiagnostics(in []string) []string {
	// Severity prefixes as written by formatDiagnostic, most severe first.
	severityOrder := [...]string{"Error", "Warn", "Info", "Hint"}
	rank := func(diag string) int {
		for i, prefix := range severityOrder {
			if strings.HasPrefix(diag, prefix) {
				return i
			}
		}
		return len(severityOrder)
	}

	sort.Slice(in, func(i, j int) bool {
		ri, rj := rank(in[i]), rank(in[j])
		if ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
153
154func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
155 severity := "Info"
156 switch diagnostic.Severity {
157 case protocol.SeverityError:
158 severity = "Error"
159 case protocol.SeverityWarning:
160 severity = "Warn"
161 case protocol.SeverityHint:
162 severity = "Hint"
163 }
164
165 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
166
167 sourceInfo := ""
168 if diagnostic.Source != "" {
169 sourceInfo = diagnostic.Source
170 } else if source != "" {
171 sourceInfo = source
172 }
173
174 codeInfo := ""
175 if diagnostic.Code != nil {
176 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
177 }
178
179 tagsInfo := ""
180 if len(diagnostic.Tags) > 0 {
181 tags := []string{}
182 for _, tag := range diagnostic.Tags {
183 switch tag {
184 case protocol.Unnecessary:
185 tags = append(tags, "unnecessary")
186 case protocol.Deprecated:
187 tags = append(tags, "deprecated")
188 }
189 }
190 if len(tags) > 0 {
191 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
192 }
193 }
194
195 return fmt.Sprintf("%s: %s [%s]%s%s %s",
196 severity,
197 location,
198 sourceInfo,
199 codeInfo,
200 tagsInfo,
201 diagnostic.Message)
202}
203
// countSeverity reports how many formatted diagnostics start with the given
// severity prefix (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}