1package tools
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "sort"
8 "strings"
9 "time"
10
11 "github.com/kujtimiihoxha/termai/internal/lsp"
12 "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
13)
14
15type diagnosticsTool struct {
16 lspClients map[string]*lsp.Client
17}
18
19const (
20 DiagnosticsToolName = "diagnostics"
21)
22
23type DiagnosticsParams struct {
24 FilePath string `json:"file_path"`
25}
26
27func (b *diagnosticsTool) Info() ToolInfo {
28 return ToolInfo{
29 Name: DiagnosticsToolName,
30 Description: "Get diagnostics for a file and/or project.",
31 Parameters: map[string]any{
32 "file_path": map[string]any{
33 "type": "string",
34 "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
35 },
36 },
37 Required: []string{},
38 }
39}
40
41func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
42 var params DiagnosticsParams
43 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
44 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
45 }
46
47 lsps := b.lspClients
48
49 if len(lsps) == 0 {
50 return NewTextErrorResponse("no LSP clients available"), nil
51 }
52
53 if params.FilePath == "" {
54 notifyLspOpenFile(ctx, params.FilePath, lsps)
55 }
56
57 output := appendDiagnostics(params.FilePath, lsps)
58
59 return NewTextResponse(output), nil
60}
61
62func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
63 for _, client := range lsps {
64 err := client.OpenFile(ctx, filePath)
65 if err != nil {
66 // Wait for the file to be opened and diagnostics to be received
67 // TODO: see if we can do this in a more efficient way
68 time.Sleep(3 * time.Second)
69 }
70
71 }
72}
73
74func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
75 fileDiagnostics := []string{}
76 projectDiagnostics := []string{}
77
78 // Enhanced format function that includes more diagnostic information
79 formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
80 // Base components
81 severity := "Info"
82 switch diagnostic.Severity {
83 case protocol.SeverityError:
84 severity = "Error"
85 case protocol.SeverityWarning:
86 severity = "Warn"
87 case protocol.SeverityHint:
88 severity = "Hint"
89 }
90
91 // Location information
92 location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
93
94 // Source information (LSP name)
95 sourceInfo := ""
96 if diagnostic.Source != "" {
97 sourceInfo = diagnostic.Source
98 } else if source != "" {
99 sourceInfo = source
100 }
101
102 // Code information
103 codeInfo := ""
104 if diagnostic.Code != nil {
105 codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
106 }
107
108 // Tags information
109 tagsInfo := ""
110 if len(diagnostic.Tags) > 0 {
111 tags := []string{}
112 for _, tag := range diagnostic.Tags {
113 switch tag {
114 case protocol.Unnecessary:
115 tags = append(tags, "unnecessary")
116 case protocol.Deprecated:
117 tags = append(tags, "deprecated")
118 }
119 }
120 if len(tags) > 0 {
121 tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
122 }
123 }
124
125 // Assemble the full diagnostic message
126 return fmt.Sprintf("%s: %s [%s]%s%s %s",
127 severity,
128 location,
129 sourceInfo,
130 codeInfo,
131 tagsInfo,
132 diagnostic.Message)
133 }
134
135 for lspName, client := range lsps {
136 diagnostics := client.GetDiagnostics()
137 if len(diagnostics) > 0 {
138 for location, diags := range diagnostics {
139 isCurrentFile := location.Path() == filePath
140
141 // Group diagnostics by severity for better organization
142 for _, diag := range diags {
143 formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
144
145 if isCurrentFile {
146 fileDiagnostics = append(fileDiagnostics, formattedDiag)
147 } else {
148 projectDiagnostics = append(projectDiagnostics, formattedDiag)
149 }
150 }
151 }
152 }
153 }
154
155 // Sort diagnostics by severity (errors first) and then by location
156 sort.Slice(fileDiagnostics, func(i, j int) bool {
157 iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
158 jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
159 if iIsError != jIsError {
160 return iIsError // Errors come first
161 }
162 return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
163 })
164
165 sort.Slice(projectDiagnostics, func(i, j int) bool {
166 iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
167 jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
168 if iIsError != jIsError {
169 return iIsError
170 }
171 return projectDiagnostics[i] < projectDiagnostics[j]
172 })
173
174 output := ""
175
176 if len(fileDiagnostics) > 0 {
177 output += "\n<file_diagnostics>\n"
178 if len(fileDiagnostics) > 10 {
179 output += strings.Join(fileDiagnostics[:10], "\n")
180 output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
181 } else {
182 output += strings.Join(fileDiagnostics, "\n")
183 }
184 output += "\n</file_diagnostics>\n"
185 }
186
187 if len(projectDiagnostics) > 0 {
188 output += "\n<project_diagnostics>\n"
189 if len(projectDiagnostics) > 10 {
190 output += strings.Join(projectDiagnostics[:10], "\n")
191 output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
192 } else {
193 output += strings.Join(projectDiagnostics, "\n")
194 }
195 output += "\n</project_diagnostics>\n"
196 }
197
198 // Add summary counts
199 if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
200 fileErrors := countSeverity(fileDiagnostics, "Error")
201 fileWarnings := countSeverity(fileDiagnostics, "Warn")
202 projectErrors := countSeverity(projectDiagnostics, "Error")
203 projectWarnings := countSeverity(projectDiagnostics, "Warn")
204
205 output += "\n<diagnostic_summary>\n"
206 output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
207 output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
208 output += "</diagnostic_summary>\n"
209 }
210
211 return output
212}
213
// countSeverity reports how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}
224
225func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
226 return &diagnosticsTool{
227 lspClients,
228 }
229}