1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log/slog"
8 "sort"
9 "strings"
10 "time"
11
12 "charm.land/fantasy"
13 "github.com/charmbracelet/crush/internal/lsp"
14 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
15)
16
// DiagnosticsParams holds the arguments accepted by the lsp_diagnostics tool.
type DiagnosticsParams struct {
	// FilePath is the file whose diagnostics are reported separately under
	// <file_diagnostics>. When empty, no file is opened or notified before
	// diagnostics are collected, and nothing is singled out as the current file.
	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave empty for project diagnostics)"`
}
20
// DiagnosticsToolName is the tool name passed to fantasy.NewAgentTool for the
// diagnostics tool.
const DiagnosticsToolName = "lsp_diagnostics"

// diagnosticsDescription holds the contents of diagnostics.md, which is used
// verbatim as the tool's description.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
25
26func NewDiagnosticsTool(lspManager *lsp.Manager) fantasy.AgentTool {
27 return fantasy.NewAgentTool(
28 DiagnosticsToolName,
29 string(diagnosticsDescription),
30 func(ctx context.Context, params DiagnosticsParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
31 if lspManager.Clients().Len() == 0 {
32 return fantasy.NewTextErrorResponse("no LSP clients available"), nil
33 }
34 notifyLSPs(ctx, lspManager, params.FilePath)
35 output := getDiagnostics(params.FilePath, lspManager)
36 return fantasy.NewTextResponse(output), nil
37 })
38}
39
40func notifyLSPs(
41 ctx context.Context,
42 manager *lsp.Manager,
43 filepath string,
44) {
45 if filepath == "" {
46 return
47 }
48
49 if manager == nil {
50 return
51 }
52
53 manager.Start(ctx, filepath)
54
55 for client := range manager.Clients().Seq() {
56 if !client.HandlesFile(filepath) {
57 continue
58 }
59 _ = client.OpenFileOnDemand(ctx, filepath)
60 _ = client.NotifyChange(ctx, filepath)
61 client.WaitForDiagnostics(ctx, 5*time.Second)
62 }
63}
64
65func getDiagnostics(filePath string, manager *lsp.Manager) string {
66 if manager == nil {
67 return ""
68 }
69
70 var fileDiagnostics []string
71 var projectDiagnostics []string
72
73 for lspName, client := range manager.Clients().Seq2() {
74 for location, diags := range client.GetDiagnostics() {
75 path, err := location.Path()
76 if err != nil {
77 slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
78 continue
79 }
80 isCurrentFile := path == filePath
81 for _, diag := range diags {
82 formattedDiag := formatDiagnostic(path, diag, lspName)
83 if isCurrentFile {
84 fileDiagnostics = append(fileDiagnostics, formattedDiag)
85 } else {
86 projectDiagnostics = append(projectDiagnostics, formattedDiag)
87 }
88 }
89 }
90 }
91
92 sortDiagnostics(fileDiagnostics)
93 sortDiagnostics(projectDiagnostics)
94
95 var output strings.Builder
96 writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
97 writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
98
99 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
100 fileErrors := countSeverity(fileDiagnostics, "Error")
101 fileWarnings := countSeverity(fileDiagnostics, "Warn")
102 projectErrors := countSeverity(projectDiagnostics, "Error")
103 projectWarnings := countSeverity(projectDiagnostics, "Warn")
104 output.WriteString("\n<diagnostic_summary>\n")
105 fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
106 fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
107 output.WriteString("</diagnostic_summary>\n")
108 }
109
110 out := output.String()
111 slog.Debug("Diagnostics", "output", out)
112 return out
113}
114
// writeDiagnostics appends the entries of in to output, one per line, wrapped
// in <tag>...</tag>.
//
// At most 10 entries are written; any remainder is summarised with a trailing
// "... and N more diagnostics" line. Nothing is written when in is empty.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	const maxShown = 10

	if len(in) == 0 {
		return
	}

	shown := in
	if len(shown) > maxShown {
		shown = shown[:maxShown]
	}

	output.WriteString("\n<" + tag + ">\n")
	output.WriteString(strings.Join(shown, "\n"))
	if hidden := len(in) - maxShown; hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	output.WriteString("\n</" + tag + ">\n")
}
128
// sortDiagnostics sorts formatted diagnostics in place and returns in for
// convenience.
//
// Entries are ordered by severity prefix (Error, Warn, Info, Hint, then
// anything else) and alphabetically within the same severity.
//
// Severity order matters because writeDiagnostics shows only the first 10
// entries. Sorting everything after errors alphabetically would rank "Hint"
// and "Info" ahead of "Warn", so warnings could be truncated away while
// lower-priority notes are shown.
func sortDiagnostics(in []string) []string {
	rank := func(diag string) int {
		switch {
		case strings.HasPrefix(diag, "Error"):
			return 0
		case strings.HasPrefix(diag, "Warn"):
			return 1
		case strings.HasPrefix(diag, "Info"):
			return 2
		case strings.HasPrefix(diag, "Hint"):
			return 3
		default:
			return 4
		}
	}
	sort.Slice(in, func(i, j int) bool {
		ri, rj := rank(in[i]), rank(in[j])
		if ri != rj {
			return ri < rj // more severe first
		}
		return in[i] < in[j] // then alphabetically
	})
	return in
}
140
141func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
142 severity := "Info"
143 switch diagnostic.Severity {
144 case protocol.SeverityError:
145 severity = "Error"
146 case protocol.SeverityWarning:
147 severity = "Warn"
148 case protocol.SeverityHint:
149 severity = "Hint"
150 }
151
152 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
153
154 sourceInfo := source
155 if diagnostic.Source != "" {
156 sourceInfo += " " + diagnostic.Source
157 }
158
159 codeInfo := ""
160 if diagnostic.Code != nil {
161 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
162 }
163
164 tagsInfo := ""
165 if len(diagnostic.Tags) > 0 {
166 var tags []string
167 for _, tag := range diagnostic.Tags {
168 switch tag {
169 case protocol.Unnecessary:
170 tags = append(tags, "unnecessary")
171 case protocol.Deprecated:
172 tags = append(tags, "deprecated")
173 }
174 }
175 if len(tags) > 0 {
176 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
177 }
178 }
179
180 return fmt.Sprintf("%s: %s [%s]%s%s %s",
181 severity,
182 location,
183 sourceInfo,
184 codeInfo,
185 tagsInfo,
186 diagnostic.Message)
187}
188
// countSeverity reports how many formatted diagnostics start with the given
// severity prefix (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	var total int
	for _, d := range diagnostics {
		if !strings.HasPrefix(d, severity) {
			continue
		}
		total++
	}
	return total
}