1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "log/slog"
8 "sort"
9 "strings"
10 "time"
11
12 "charm.land/fantasy"
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/charmbracelet/crush/internal/lsp"
15 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
16)
17
// DiagnosticsParams is the input accepted by the lsp_diagnostics tool.
// FilePath is optional; when it is empty only project-wide diagnostics are
// requested and no LSP client is notified about a specific file.
type DiagnosticsParams struct {
	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave empty for project diagnostics)"`
}
21
// DiagnosticsToolName is the name passed to fantasy.NewAgentTool when the
// LSP diagnostics tool is constructed.
const DiagnosticsToolName = "lsp_diagnostics"

// diagnosticsDescription is the tool description handed to
// fantasy.NewAgentTool, embedded at build time from diagnostics.md.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
26
// NewDiagnosticsTool builds the lsp_diagnostics agent tool backed by the given
// set of LSP clients.
//
// Each invocation fails softly (an error text response, nil error) when no
// client is registered. Otherwise it first notifies the clients that handle
// the requested file, if one was given, and then returns the rendered
// diagnostics from all clients as a plain text response.
func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
	run := func(ctx context.Context, params DiagnosticsParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
		// Without any language server there is nothing to ask.
		if lspClients.Len() < 1 {
			return fantasy.NewTextErrorResponse("no LSP clients available"), nil
		}
		target := params.FilePath
		notifyLSPs(ctx, lspClients, target)
		return fantasy.NewTextResponse(getDiagnostics(target, lspClients)), nil
	}
	return fantasy.NewAgentTool(DiagnosticsToolName, string(diagnosticsDescription), run)
}
40
41func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
42 if filepath == "" {
43 return
44 }
45 for client := range lsps.Seq() {
46 if !client.HandlesFile(filepath) {
47 continue
48 }
49 _ = client.OpenFileOnDemand(ctx, filepath)
50 _ = client.NotifyChange(ctx, filepath)
51 client.WaitForDiagnostics(ctx, 5*time.Second)
52 }
53}
54
// getDiagnostics collects the diagnostics currently held by every LSP client
// and renders them as tagged text sections.
//
// A diagnostic whose location path is exactly equal to filePath (plain string
// comparison) goes into <file_diagnostics>. Every other diagnostic goes into
// <project_diagnostics>. Each section is sorted with sortDiagnostics and
// written with writeDiagnostics, which shows at most the first 10 entries.
// When anything was reported, a <diagnostic_summary> with error/warning counts
// for both sections is appended. The result is the empty string when no
// client reports any diagnostic.
func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
	var fileDiagnostics, projectDiagnostics []string

	for lspName, client := range lsps.Seq2() {
		for location, diags := range client.GetDiagnostics() {
			path, err := location.Path()
			if err != nil {
				slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
				continue
			}
			// Exact match only; filePath must be in the same form as the path
			// derived from the URI to be treated as the current file.
			isCurrentFile := path == filePath
			for _, diag := range diags {
				formatted := formatDiagnostic(path, diag, lspName)
				if isCurrentFile {
					fileDiagnostics = append(fileDiagnostics, formatted)
				} else {
					projectDiagnostics = append(projectDiagnostics, formatted)
				}
			}
		}
	}

	sortDiagnostics(fileDiagnostics)
	sortDiagnostics(projectDiagnostics)

	var output strings.Builder
	writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
	writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)

	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
		// Counts cover every diagnostic, including entries truncated from the
		// sections above.
		fileErrors := countSeverity(fileDiagnostics, "Error")
		fileWarnings := countSeverity(fileDiagnostics, "Warn")
		projectErrors := countSeverity(projectDiagnostics, "Error")
		projectWarnings := countSeverity(projectDiagnostics, "Warn")
		output.WriteString("\n<diagnostic_summary>\n")
		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
		output.WriteString("</diagnostic_summary>\n")
	}

	out := output.String()
	// The full payload can be large and is produced on every tool call; keep
	// it out of the default (info) log level.
	slog.Debug("Diagnostics", "output", out)
	return out
}
100
101func writeDiagnostics(output *strings.Builder, tag string, in []string) {
102 if len(in) == 0 {
103 return
104 }
105 output.WriteString("\n<" + tag + ">\n")
106 if len(in) > 10 {
107 output.WriteString(strings.Join(in[:10], "\n"))
108 fmt.Fprintf(output, "\n... and %d more diagnostics", len(in)-10)
109 } else {
110 output.WriteString(strings.Join(in, "\n"))
111 }
112 output.WriteString("\n</" + tag + ">\n")
113}
114
// sortDiagnostics sorts formatted diagnostics in place and returns the same
// slice for convenience. The order is by severity label (Error, Warn, Info,
// Hint, then anything unrecognized), and alphabetical within each severity.
//
// Every severity is ranked, not just errors, because writeDiagnostics shows
// only the first entries of each section. With a purely alphabetical tail,
// "Hint"/"Info" lines sort ahead of "Warn" lines and can push warnings out of
// the visible window.
func sortDiagnostics(in []string) []string {
	// Prefixes match the severity labels produced by formatDiagnostic.
	severityRank := func(diag string) int {
		switch {
		case strings.HasPrefix(diag, "Error"):
			return 0
		case strings.HasPrefix(diag, "Warn"):
			return 1
		case strings.HasPrefix(diag, "Info"):
			return 2
		case strings.HasPrefix(diag, "Hint"):
			return 3
		default:
			return 4
		}
	}
	sort.Slice(in, func(i, j int) bool {
		ri, rj := severityRank(in[i]), severityRank(in[j])
		if ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
126
127func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
128 severity := "Info"
129 switch diagnostic.Severity {
130 case protocol.SeverityError:
131 severity = "Error"
132 case protocol.SeverityWarning:
133 severity = "Warn"
134 case protocol.SeverityHint:
135 severity = "Hint"
136 }
137
138 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
139
140 sourceInfo := ""
141 if diagnostic.Source != "" {
142 sourceInfo = diagnostic.Source
143 } else if source != "" {
144 sourceInfo = source
145 }
146
147 codeInfo := ""
148 if diagnostic.Code != nil {
149 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
150 }
151
152 tagsInfo := ""
153 if len(diagnostic.Tags) > 0 {
154 tags := []string{}
155 for _, tag := range diagnostic.Tags {
156 switch tag {
157 case protocol.Unnecessary:
158 tags = append(tags, "unnecessary")
159 case protocol.Deprecated:
160 tags = append(tags, "deprecated")
161 }
162 }
163 if len(tags) > 0 {
164 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
165 }
166 }
167
168 return fmt.Sprintf("%s: %s [%s]%s%s %s",
169 severity,
170 location,
171 sourceInfo,
172 codeInfo,
173 tagsInfo,
174 diagnostic.Message)
175}
176
// countSeverity reports how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}