1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "maps"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/kujtimiihoxha/termai/internal/lsp"
13 "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
14)
15
16type diagnosticsTool struct {
17 lspClients map[string]*lsp.Client
18}
19
// DiagnosticsToolName is the name under which the diagnostics tool is exposed.
const DiagnosticsToolName = "diagnostics"

// DiagnosticsParams is the JSON input accepted by the diagnostics tool.
type DiagnosticsParams struct {
	// FilePath is the file to report on. When empty, no file is opened and
	// every cached diagnostic is reported as a project diagnostic.
	FilePath string `json:"file_path"`
}
27
28func (b *diagnosticsTool) Info() ToolInfo {
29 return ToolInfo{
30 Name: DiagnosticsToolName,
31 Description: "Get diagnostics for a file and/or project.",
32 Parameters: map[string]any{
33 "file_path": map[string]any{
34 "type": "string",
35 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
36 },
37 },
38 Required: []string{},
39 }
40}
41
42func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
43 var params DiagnosticsParams
44 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
45 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
46 }
47
48 lsps := b.lspClients
49
50 if len(lsps) == 0 {
51 return NewTextErrorResponse("no LSP clients available"), nil
52 }
53
54 if params.FilePath != "" {
55 notifyLspOpenFile(ctx, params.FilePath, lsps)
56 }
57
58 output := appendDiagnostics(params.FilePath, lsps)
59
60 return NewTextResponse(output), nil
61}
62
63func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
64 // Create a channel to receive diagnostic notifications
65 diagChan := make(chan struct{}, 1)
66
67 // Register a temporary diagnostic handler for each client
68 for _, client := range lsps {
69 // Store the original diagnostics map to detect changes
70 originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
71 maps.Copy(originalDiags, client.GetDiagnostics())
72
73 // Create a notification handler that will signal when diagnostics are received
74 handler := func(params json.RawMessage) {
75 var diagParams protocol.PublishDiagnosticsParams
76 if err := json.Unmarshal(params, &diagParams); err != nil {
77 return
78 }
79
80 // If this is for our file or we've received any new diagnostics, signal completion
81 if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
82 select {
83 case diagChan <- struct{}{}:
84 // Signal sent
85 default:
86 // Channel already has a value, no need to send again
87 }
88 }
89 }
90
91 // Register our temporary handler
92 client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
93
94 // Open the file
95 err := client.OpenFile(ctx, filePath)
96 if err != nil {
97 // If there's an error opening the file, continue to the next client
98 continue
99 }
100 }
101
102 // Wait for diagnostics with a reasonable timeout
103 select {
104 case <-diagChan:
105 // Diagnostics received
106 case <-time.After(5 * time.Second):
107 // Timeout after 2 seconds - this is a fallback in case no diagnostics are published
108 case <-ctx.Done():
109 // Context cancelled
110 }
111
112 // Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
113 // replaces any existing handler, and we'll be replaced by the original handler when
114 // the LSP client is reinitialized or when a new handler is registered.
115}
116
// hasDiagnosticsChanged reports whether current differs from original in a
// way visible from counts alone. That is either a document present in only
// one of the two maps, or a document whose number of diagnostics changed.
//
// Diagnostics whose content changes while the per-document count stays the
// same are not detected. The function is generic over the key and element
// types. The existing call site, with
// map[protocol.DocumentUri][]protocol.Diagnostic arguments, is inferred
// unchanged.
func hasDiagnosticsChanged[K comparable, D any](current, original map[K][]D) bool {
	// A different number of documents means one appeared or disappeared.
	// Scanning only current (below) cannot see a document that was dropped
	// from it.
	if len(current) != len(original) {
		return true
	}
	for uri, diags := range current {
		origDiags, exists := original[uri]
		if !exists || len(diags) != len(origDiags) {
			return true
		}
	}
	return false
}
127
128func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
129 fileDiagnostics := []string{}
130 projectDiagnostics := []string{}
131
132 // Enhanced format function that includes more diagnostic information
133 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
134 // Base components
135 severity := "Info"
136 switch diagnostic.Severity {
137 case protocol.SeverityError:
138 severity = "Error"
139 case protocol.SeverityWarning:
140 severity = "Warn"
141 case protocol.SeverityHint:
142 severity = "Hint"
143 }
144
145 // Location information
146 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
147
148 // Source information (LSP name)
149 sourceInfo := ""
150 if diagnostic.Source != "" {
151 sourceInfo = diagnostic.Source
152 } else if source != "" {
153 sourceInfo = source
154 }
155
156 // Code information
157 codeInfo := ""
158 if diagnostic.Code != nil {
159 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
160 }
161
162 // Tags information
163 tagsInfo := ""
164 if len(diagnostic.Tags) > 0 {
165 tags := []string{}
166 for _, tag := range diagnostic.Tags {
167 switch tag {
168 case protocol.Unnecessary:
169 tags = append(tags, "unnecessary")
170 case protocol.Deprecated:
171 tags = append(tags, "deprecated")
172 }
173 }
174 if len(tags) > 0 {
175 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
176 }
177 }
178
179 // Assemble the full diagnostic message
180 return fmt.Sprintf("%s: %s [%s]%s%s %s",
181 severity,
182 location,
183 sourceInfo,
184 codeInfo,
185 tagsInfo,
186 diagnostic.Message)
187 }
188
189 for lspName, client := range lsps {
190 diagnostics := client.GetDiagnostics()
191 if len(diagnostics) > 0 {
192 for location, diags := range diagnostics {
193 isCurrentFile := location.Path() == filePath
194
195 // Group diagnostics by severity for better organization
196 for _, diag := range diags {
197 formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
198
199 if isCurrentFile {
200 fileDiagnostics = append(fileDiagnostics, formattedDiag)
201 } else {
202 projectDiagnostics = append(projectDiagnostics, formattedDiag)
203 }
204 }
205 }
206 }
207 }
208
209 // Sort diagnostics by severity (errors first) and then by location
210 sort.Slice(fileDiagnostics, func(i, j int) bool {
211 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
212 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
213 if iIsError != jIsError {
214 return iIsError // Errors come first
215 }
216 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
217 })
218
219 sort.Slice(projectDiagnostics, func(i, j int) bool {
220 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
221 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
222 if iIsError != jIsError {
223 return iIsError
224 }
225 return projectDiagnostics[i] < projectDiagnostics[j]
226 })
227
228 output := ""
229
230 if len(fileDiagnostics) > 0 {
231 output += "\n<file_diagnostics>\n"
232 if len(fileDiagnostics) > 10 {
233 output += strings.Join(fileDiagnostics[:10], "\n")
234 output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
235 } else {
236 output += strings.Join(fileDiagnostics, "\n")
237 }
238 output += "\n</file_diagnostics>\n"
239 }
240
241 if len(projectDiagnostics) > 0 {
242 output += "\n<project_diagnostics>\n"
243 if len(projectDiagnostics) > 10 {
244 output += strings.Join(projectDiagnostics[:10], "\n")
245 output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
246 } else {
247 output += strings.Join(projectDiagnostics, "\n")
248 }
249 output += "\n</project_diagnostics>\n"
250 }
251
252 // Add summary counts
253 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
254 fileErrors := countSeverity(fileDiagnostics, "Error")
255 fileWarnings := countSeverity(fileDiagnostics, "Warn")
256 projectErrors := countSeverity(projectDiagnostics, "Error")
257 projectWarnings := countSeverity(projectDiagnostics, "Warn")
258
259 output += "\n<diagnostic_summary>\n"
260 output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
261 output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
262 output += "</diagnostic_summary>\n"
263 }
264
265 return output
266}
267
// countSeverity returns the number of rendered diagnostics that begin with
// the given severity label, such as "Error" or "Warn".
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}
278
279func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
280 return &diagnosticsTool{
281 lspClients,
282 }
283}