1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "maps"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/kujtimiihoxha/termai/internal/lsp"
13 "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
14)
15
16type diagnosticsTool struct {
17 lspClients map[string]*lsp.Client
18}
19
// DiagnosticsToolName is the name under which the diagnostics tool is exposed.
const DiagnosticsToolName = "diagnostics"
23
24type DiagnosticsParams struct {
25 FilePath string `json:"file_path"`
26}
27
28func (b *diagnosticsTool) Info() ToolInfo {
29 return ToolInfo{
30 Name: DiagnosticsToolName,
31 Description: "Get diagnostics for a file and/or project.",
32 Parameters: map[string]any{
33 "file_path": map[string]any{
34 "type": "string",
35 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
36 },
37 },
38 Required: []string{},
39 }
40}
41
42func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
43 var params DiagnosticsParams
44 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
45 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
46 }
47
48 lsps := b.lspClients
49
50 if len(lsps) == 0 {
51 return NewTextErrorResponse("no LSP clients available"), nil
52 }
53
54 if params.FilePath != "" {
55 notifyLspOpenFile(ctx, params.FilePath, lsps)
56 }
57
58 output := appendDiagnostics(params.FilePath, lsps)
59
60 return NewTextResponse(output), nil
61}
62
63func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
64 // Create a channel to receive diagnostic notifications
65 diagChan := make(chan struct{}, 1)
66
67 // Register a temporary diagnostic handler for each client
68 for _, client := range lsps {
69 // Store the original diagnostics map to detect changes
70 originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
71 maps.Copy(originalDiags, client.GetDiagnostics())
72
73 // Create a notification handler that will signal when diagnostics are received
74 handler := func(params json.RawMessage) {
75 lsp.HandleDiagnostics(client, params)
76 var diagParams protocol.PublishDiagnosticsParams
77 if err := json.Unmarshal(params, &diagParams); err != nil {
78 return
79 }
80
81 // If this is for our file or we've received any new diagnostics, signal completion
82 if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
83 select {
84 case diagChan <- struct{}{}:
85 // Signal sent
86 default:
87 // Channel already has a value, no need to send again
88 }
89 }
90 }
91
92 // Register our temporary handler
93 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
94
95 // Open the file
96 err := client.OpenFile(ctx, filePath)
97 if err != nil {
98 // If there's an error opening the file, continue to the next client
99 continue
100 }
101 }
102
103 // Wait for diagnostics with a reasonable timeout
104 select {
105 case <-diagChan:
106 // Diagnostics received
107 case <-time.After(10 * time.Second):
108 // Timeout after 5 seconds - this is a fallback in case no diagnostics are published
109 case <-ctx.Done():
110 // Context cancelled
111 }
112
113 // Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
114 // replaces any existing handler, and we'll be replaced by the original handler when
115 // the LSP client is reinitialized or when a new handler is registered.
116}
117
118// hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
119func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
120 for uri, diags := range current {
121 origDiags, exists := original[uri]
122 if !exists || len(diags) != len(origDiags) {
123 return true
124 }
125 }
126 return false
127}
128
129func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
130 fileDiagnostics := []string{}
131 projectDiagnostics := []string{}
132
133 // Enhanced format function that includes more diagnostic information
134 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
135 // Base components
136 severity := "Info"
137 switch diagnostic.Severity {
138 case protocol.SeverityError:
139 severity = "Error"
140 case protocol.SeverityWarning:
141 severity = "Warn"
142 case protocol.SeverityHint:
143 severity = "Hint"
144 }
145
146 // Location information
147 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
148
149 // Source information (LSP name)
150 sourceInfo := ""
151 if diagnostic.Source != "" {
152 sourceInfo = diagnostic.Source
153 } else if source != "" {
154 sourceInfo = source
155 }
156
157 // Code information
158 codeInfo := ""
159 if diagnostic.Code != nil {
160 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
161 }
162
163 // Tags information
164 tagsInfo := ""
165 if len(diagnostic.Tags) > 0 {
166 tags := []string{}
167 for _, tag := range diagnostic.Tags {
168 switch tag {
169 case protocol.Unnecessary:
170 tags = append(tags, "unnecessary")
171 case protocol.Deprecated:
172 tags = append(tags, "deprecated")
173 }
174 }
175 if len(tags) > 0 {
176 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
177 }
178 }
179
180 // Assemble the full diagnostic message
181 return fmt.Sprintf("%s: %s [%s]%s%s %s",
182 severity,
183 location,
184 sourceInfo,
185 codeInfo,
186 tagsInfo,
187 diagnostic.Message)
188 }
189
190 for lspName, client := range lsps {
191 diagnostics := client.GetDiagnostics()
192 if len(diagnostics) > 0 {
193 for location, diags := range diagnostics {
194 isCurrentFile := location.Path() == filePath
195
196 // Group diagnostics by severity for better organization
197 for _, diag := range diags {
198 formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
199
200 if isCurrentFile {
201 fileDiagnostics = append(fileDiagnostics, formattedDiag)
202 } else {
203 projectDiagnostics = append(projectDiagnostics, formattedDiag)
204 }
205 }
206 }
207 }
208 }
209
210 // Sort diagnostics by severity (errors first) and then by location
211 sort.Slice(fileDiagnostics, func(i, j int) bool {
212 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
213 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
214 if iIsError != jIsError {
215 return iIsError // Errors come first
216 }
217 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
218 })
219
220 sort.Slice(projectDiagnostics, func(i, j int) bool {
221 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
222 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
223 if iIsError != jIsError {
224 return iIsError
225 }
226 return projectDiagnostics[i] < projectDiagnostics[j]
227 })
228
229 output := ""
230
231 if len(fileDiagnostics) > 0 {
232 output += "\n<file_diagnostics>\n"
233 if len(fileDiagnostics) > 10 {
234 output += strings.Join(fileDiagnostics[:10], "\n")
235 output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
236 } else {
237 output += strings.Join(fileDiagnostics, "\n")
238 }
239 output += "\n</file_diagnostics>\n"
240 }
241
242 if len(projectDiagnostics) > 0 {
243 output += "\n<project_diagnostics>\n"
244 if len(projectDiagnostics) > 10 {
245 output += strings.Join(projectDiagnostics[:10], "\n")
246 output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
247 } else {
248 output += strings.Join(projectDiagnostics, "\n")
249 }
250 output += "\n</project_diagnostics>\n"
251 }
252
253 // Add summary counts
254 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
255 fileErrors := countSeverity(fileDiagnostics, "Error")
256 fileWarnings := countSeverity(fileDiagnostics, "Warn")
257 projectErrors := countSeverity(projectDiagnostics, "Error")
258 projectWarnings := countSeverity(projectDiagnostics, "Warn")
259
260 output += "\n<diagnostic_summary>\n"
261 output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
262 output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
263 output += "</diagnostic_summary>\n"
264 }
265
266 return output
267}
268
// countSeverity returns how many formatted diagnostic lines start with the
// given severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}
279
280func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
281 return &diagnosticsTool{
282 lspClients,
283 }
284}