1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "maps"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/kujtimiihoxha/termai/internal/lsp"
13 "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
14)
15
16type diagnosticsTool struct {
17 lspClients map[string]*lsp.Client
18}
19
// DiagnosticsToolName is the name under which the diagnostics tool is exposed.
const DiagnosticsToolName = "diagnostics"
23
// DiagnosticsParams holds the decoded JSON input of the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath selects the file to report on; empty means project-wide only.
	FilePath string `json:"file_path"`
}
27
28func (b *diagnosticsTool) Info() ToolInfo {
29 return ToolInfo{
30 Name: DiagnosticsToolName,
31 Description: "Get diagnostics for a file and/or project.",
32 Parameters: map[string]any{
33 "file_path": map[string]any{
34 "type": "string",
35 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
36 },
37 },
38 Required: []string{},
39 }
40}
41
42func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
43 var params DiagnosticsParams
44 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
45 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
46 }
47
48 lsps := b.lspClients
49
50 if len(lsps) == 0 {
51 return NewTextErrorResponse("no LSP clients available"), nil
52 }
53
54 if params.FilePath != "" {
55 notifyLspOpenFile(ctx, params.FilePath, lsps)
56 waitForLspDiagnostics(ctx, params.FilePath, lsps)
57 }
58
59 output := appendDiagnostics(params.FilePath, lsps)
60
61 return NewTextResponse(output), nil
62}
63
64func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
65 for _, client := range lsps {
66 // Open the file
67 err := client.OpenFile(ctx, filePath)
68 if err != nil {
69 // If there's an error opening the file, continue to the next client
70 continue
71 }
72 }
73}
74
75// waitForLspDiagnostics opens a file in LSP clients and waits for diagnostics to be published
76func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
77 if len(lsps) == 0 {
78 return
79 }
80
81 // Create a channel to receive diagnostic notifications
82 diagChan := make(chan struct{}, 1)
83
84 // Register a temporary diagnostic handler for each client
85 for _, client := range lsps {
86 // Store the original diagnostics map to detect changes
87 originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
88 maps.Copy(originalDiags, client.GetDiagnostics())
89
90 // Create a notification handler that will signal when diagnostics are received
91 handler := func(params json.RawMessage) {
92 lsp.HandleDiagnostics(client, params)
93 var diagParams protocol.PublishDiagnosticsParams
94 if err := json.Unmarshal(params, &diagParams); err != nil {
95 return
96 }
97
98 // If this is for our file or we've received any new diagnostics, signal completion
99 if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
100 select {
101 case diagChan <- struct{}{}:
102 // Signal sent
103 default:
104 // Channel already has a value, no need to send again
105 }
106 }
107 }
108
109 // Register our temporary handler
110 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
111
112 // Notify change if the file is already open
113 if client.IsFileOpen(filePath) {
114 err := client.NotifyChange(ctx, filePath)
115 if err != nil {
116 continue
117 }
118 } else {
119 // Open the file if it's not already open
120 err := client.OpenFile(ctx, filePath)
121 if err != nil {
122 continue
123 }
124 }
125 }
126
127 // Wait for diagnostics with a reasonable timeout
128 select {
129 case <-diagChan:
130 // Diagnostics received
131 case <-time.After(5 * time.Second):
132 // Timeout after 5 seconds - this is a fallback in case no diagnostics are published
133 case <-ctx.Done():
134 // Context cancelled
135 }
136
137 // Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
138 // replaces any existing handler, and we'll be replaced by the original handler when
139 // the LSP client is reinitialized or when a new handler is registered.
140}
141
142// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
143func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
144 for uri, diags := range current {
145 origDiags, exists := original[uri]
146 if !exists || len(diags) != len(origDiags) {
147 return true
148 }
149 }
150 return false
151}
152
153func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
154 fileDiagnostics := []string{}
155 projectDiagnostics := []string{}
156
157 // Enhanced format function that includes more diagnostic information
158 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
159 // Base components
160 severity := "Info"
161 switch diagnostic.Severity {
162 case protocol.SeverityError:
163 severity = "Error"
164 case protocol.SeverityWarning:
165 severity = "Warn"
166 case protocol.SeverityHint:
167 severity = "Hint"
168 }
169
170 // Location information
171 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
172
173 // Source information (LSP name)
174 sourceInfo := ""
175 if diagnostic.Source != "" {
176 sourceInfo = diagnostic.Source
177 } else if source != "" {
178 sourceInfo = source
179 }
180
181 // Code information
182 codeInfo := ""
183 if diagnostic.Code != nil {
184 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
185 }
186
187 // Tags information
188 tagsInfo := ""
189 if len(diagnostic.Tags) > 0 {
190 tags := []string{}
191 for _, tag := range diagnostic.Tags {
192 switch tag {
193 case protocol.Unnecessary:
194 tags = append(tags, "unnecessary")
195 case protocol.Deprecated:
196 tags = append(tags, "deprecated")
197 }
198 }
199 if len(tags) > 0 {
200 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
201 }
202 }
203
204 // Assemble the full diagnostic message
205 return fmt.Sprintf("%s: %s [%s]%s%s %s",
206 severity,
207 location,
208 sourceInfo,
209 codeInfo,
210 tagsInfo,
211 diagnostic.Message)
212 }
213
214 for lspName, client := range lsps {
215 diagnostics := client.GetDiagnostics()
216 if len(diagnostics) > 0 {
217 for location, diags := range diagnostics {
218 isCurrentFile := location.Path() == filePath
219
220 // Group diagnostics by severity for better organization
221 for _, diag := range diags {
222 formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
223
224 if isCurrentFile {
225 fileDiagnostics = append(fileDiagnostics, formattedDiag)
226 } else {
227 projectDiagnostics = append(projectDiagnostics, formattedDiag)
228 }
229 }
230 }
231 }
232 }
233
234 // Sort diagnostics by severity (errors first) and then by location
235 sort.Slice(fileDiagnostics, func(i, j int) bool {
236 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
237 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
238 if iIsError != jIsError {
239 return iIsError // Errors come first
240 }
241 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
242 })
243
244 sort.Slice(projectDiagnostics, func(i, j int) bool {
245 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
246 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
247 if iIsError != jIsError {
248 return iIsError
249 }
250 return projectDiagnostics[i] < projectDiagnostics[j]
251 })
252
253 output := ""
254
255 if len(fileDiagnostics) > 0 {
256 output += "\n<file_diagnostics>\n"
257 if len(fileDiagnostics) > 10 {
258 output += strings.Join(fileDiagnostics[:10], "\n")
259 output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
260 } else {
261 output += strings.Join(fileDiagnostics, "\n")
262 }
263 output += "\n</file_diagnostics>\n"
264 }
265
266 if len(projectDiagnostics) > 0 {
267 output += "\n<project_diagnostics>\n"
268 if len(projectDiagnostics) > 10 {
269 output += strings.Join(projectDiagnostics[:10], "\n")
270 output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
271 } else {
272 output += strings.Join(projectDiagnostics, "\n")
273 }
274 output += "\n</project_diagnostics>\n"
275 }
276
277 // Add summary counts
278 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
279 fileErrors := countSeverity(fileDiagnostics, "Error")
280 fileWarnings := countSeverity(fileDiagnostics, "Warn")
281 projectErrors := countSeverity(projectDiagnostics, "Error")
282 projectWarnings := countSeverity(projectDiagnostics, "Warn")
283
284 output += "\n<diagnostic_summary>\n"
285 output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
286 output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
287 output += "</diagnostic_summary>\n"
288 }
289
290 return output
291}
292
293// Helper function to count diagnostics by severity
294func countSeverity(diagnostics []string, severity string) int {
295 count := 0
296 for _, diag := range diagnostics {
297 if strings.HasPrefix(diag, severity) {
298 count++
299 }
300 }
301 return count
302}
303
304func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
305 return &diagnosticsTool{
306 lspClients,
307 }
308}