diagnostics.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"sort"
  8	"strings"
  9	"time"
 10
 11	"github.com/kujtimiihoxha/termai/internal/lsp"
 12	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 13)
 14
 15type diagnosticsTool struct {
 16	lspClients map[string]*lsp.Client
 17}
 18
 19const (
 20	DiagnosticsToolName = "diagnostics"
 21)
 22
 23type DiagnosticsParams struct {
 24	FilePath string `json:"file_path"`
 25}
 26
 27func (b *diagnosticsTool) Info() ToolInfo {
 28	return ToolInfo{
 29		Name:        DiagnosticsToolName,
 30		Description: "Get diagnostics for a file and/or project.",
 31		Parameters: map[string]any{
 32			"file_path": map[string]any{
 33				"type":        "string",
 34				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 35			},
 36		},
 37		Required: []string{},
 38	}
 39}
 40
 41func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 42	var params DiagnosticsParams
 43	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 44		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 45	}
 46
 47	lsps := b.lspClients
 48
 49	if len(lsps) == 0 {
 50		return NewTextErrorResponse("no LSP clients available"), nil
 51	}
 52
 53	if params.FilePath == "" {
 54		notifyLspOpenFile(ctx, params.FilePath, lsps)
 55	}
 56
 57	output := appendDiagnostics(params.FilePath, lsps)
 58
 59	return NewTextResponse(output), nil
 60}
 61
 62func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 63	for _, client := range lsps {
 64		err := client.OpenFile(ctx, filePath)
 65		if err != nil {
 66			// Wait for the file to be opened and diagnostics to be received
 67			// TODO: see if we can do this in a more efficient way
 68			time.Sleep(2 * time.Second)
 69		}
 70
 71	}
 72}
 73
 74func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
 75	fileDiagnostics := []string{}
 76	projectDiagnostics := []string{}
 77
 78	// Enhanced format function that includes more diagnostic information
 79	formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
 80		// Base components
 81		severity := "Info"
 82		switch diagnostic.Severity {
 83		case protocol.SeverityError:
 84			severity = "Error"
 85		case protocol.SeverityWarning:
 86			severity = "Warn"
 87		case protocol.SeverityHint:
 88			severity = "Hint"
 89		}
 90
 91		// Location information
 92		location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
 93
 94		// Source information (LSP name)
 95		sourceInfo := ""
 96		if diagnostic.Source != "" {
 97			sourceInfo = diagnostic.Source
 98		} else if source != "" {
 99			sourceInfo = source
100		}
101
102		// Code information
103		codeInfo := ""
104		if diagnostic.Code != nil {
105			codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
106		}
107
108		// Tags information
109		tagsInfo := ""
110		if len(diagnostic.Tags) > 0 {
111			tags := []string{}
112			for _, tag := range diagnostic.Tags {
113				switch tag {
114				case protocol.Unnecessary:
115					tags = append(tags, "unnecessary")
116				case protocol.Deprecated:
117					tags = append(tags, "deprecated")
118				}
119			}
120			if len(tags) > 0 {
121				tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
122			}
123		}
124
125		// Assemble the full diagnostic message
126		return fmt.Sprintf("%s: %s [%s]%s%s %s",
127			severity,
128			location,
129			sourceInfo,
130			codeInfo,
131			tagsInfo,
132			diagnostic.Message)
133	}
134
135	for lspName, client := range lsps {
136		diagnostics := client.GetDiagnostics()
137		if len(diagnostics) > 0 {
138			for location, diags := range diagnostics {
139				isCurrentFile := location.Path() == filePath
140
141				// Group diagnostics by severity for better organization
142				for _, diag := range diags {
143					formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
144
145					if isCurrentFile {
146						fileDiagnostics = append(fileDiagnostics, formattedDiag)
147					} else {
148						projectDiagnostics = append(projectDiagnostics, formattedDiag)
149					}
150				}
151			}
152		}
153	}
154
155	// Sort diagnostics by severity (errors first) and then by location
156	sort.Slice(fileDiagnostics, func(i, j int) bool {
157		iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
158		jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
159		if iIsError != jIsError {
160			return iIsError // Errors come first
161		}
162		return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
163	})
164
165	sort.Slice(projectDiagnostics, func(i, j int) bool {
166		iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
167		jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
168		if iIsError != jIsError {
169			return iIsError
170		}
171		return projectDiagnostics[i] < projectDiagnostics[j]
172	})
173
174	output := ""
175
176	if len(fileDiagnostics) > 0 {
177		output += "\n<file_diagnostics>\n"
178		if len(fileDiagnostics) > 10 {
179			output += strings.Join(fileDiagnostics[:10], "\n")
180			output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
181		} else {
182			output += strings.Join(fileDiagnostics, "\n")
183		}
184		output += "\n</file_diagnostics>\n"
185	}
186
187	if len(projectDiagnostics) > 0 {
188		output += "\n<project_diagnostics>\n"
189		if len(projectDiagnostics) > 10 {
190			output += strings.Join(projectDiagnostics[:10], "\n")
191			output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
192		} else {
193			output += strings.Join(projectDiagnostics, "\n")
194		}
195		output += "\n</project_diagnostics>\n"
196	}
197
198	// Add summary counts
199	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
200		fileErrors := countSeverity(fileDiagnostics, "Error")
201		fileWarnings := countSeverity(fileDiagnostics, "Warn")
202		projectErrors := countSeverity(projectDiagnostics, "Error")
203		projectWarnings := countSeverity(projectDiagnostics, "Warn")
204
205		output += "\n<diagnostic_summary>\n"
206		output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
207		output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
208		output += "</diagnostic_summary>\n"
209	}
210
211	return output
212}
213
// countSeverity reports how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn"). An empty label matches
// every entry.
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}
224
225func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
226	return &diagnosticsTool{
227		lspClients,
228	}
229}