diagnostics.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"sort"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/lsp"
 15	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 16)
 17
 18type DiagnosticsParams struct {
 19	FilePath string `json:"file_path"`
 20}
 21
 22type diagnosticsTool struct {
 23	lspClients *csync.Map[string, *lsp.Client]
 24}
 25
 26const DiagnosticsToolName = "lsp_diagnostics"
 27
 28//go:embed diagnostics.md
 29var diagnosticsDescription []byte
 30
 31func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
 32	return &diagnosticsTool{
 33		lspClients,
 34	}
 35}
 36
 37func (b *diagnosticsTool) Name() string {
 38	return DiagnosticsToolName
 39}
 40
 41func (b *diagnosticsTool) Info() ToolInfo {
 42	return ToolInfo{
 43		Name:        DiagnosticsToolName,
 44		Description: string(diagnosticsDescription),
 45		Parameters: map[string]any{
 46			"file_path": map[string]any{
 47				"type":        "string",
 48				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 49			},
 50		},
 51		Required: []string{},
 52	}
 53}
 54
 55func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 56	var params DiagnosticsParams
 57	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 58		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 59	}
 60
 61	if b.lspClients.Len() == 0 {
 62		return NewTextErrorResponse("no LSP clients available"), nil
 63	}
 64	notifyLSPs(ctx, b.lspClients, params.FilePath)
 65	output := getDiagnostics(params.FilePath, b.lspClients)
 66	return NewTextResponse(output), nil
 67}
 68
 69func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
 70	if filepath == "" {
 71		return
 72	}
 73	for client := range lsps.Seq() {
 74		if !client.HandlesFile(filepath) {
 75			continue
 76		}
 77		_ = client.OpenFileOnDemand(ctx, filepath)
 78		_ = client.NotifyChange(ctx, filepath)
 79		client.WaitForDiagnostics(ctx, 5*time.Second)
 80	}
 81}
 82
 83func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
 84	fileDiagnostics := []string{}
 85	projectDiagnostics := []string{}
 86
 87	for lspName, client := range lsps.Seq2() {
 88		for location, diags := range client.GetDiagnostics() {
 89			path, err := location.Path()
 90			if err != nil {
 91				slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
 92				continue
 93			}
 94			isCurrentFile := path == filePath
 95			for _, diag := range diags {
 96				formattedDiag := formatDiagnostic(path, diag, lspName)
 97				if isCurrentFile {
 98					fileDiagnostics = append(fileDiagnostics, formattedDiag)
 99				} else {
100					projectDiagnostics = append(projectDiagnostics, formattedDiag)
101				}
102			}
103		}
104	}
105
106	sortDiagnostics(fileDiagnostics)
107	sortDiagnostics(projectDiagnostics)
108
109	var output strings.Builder
110	writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
111	writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
112
113	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
114		fileErrors := countSeverity(fileDiagnostics, "Error")
115		fileWarnings := countSeverity(fileDiagnostics, "Warn")
116		projectErrors := countSeverity(projectDiagnostics, "Error")
117		projectWarnings := countSeverity(projectDiagnostics, "Warn")
118		output.WriteString("\n<diagnostic_summary>\n")
119		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
120		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
121		output.WriteString("</diagnostic_summary>\n")
122	}
123
124	out := output.String()
125	slog.Info("Diagnostics", "output", out)
126	return out
127}
128
// writeDiagnostics appends the lines in `in` to output, wrapped in <tag> and
// </tag>.
//
// At most the first 10 lines are written. When more exist, a trailing
// "... and N more diagnostics" line reports how many were left out. Nothing is
// written for an empty list.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	if len(in) == 0 {
		return
	}

	const maxShown = 10
	shown := in
	if len(shown) > maxShown {
		shown = shown[:maxShown]
	}

	output.WriteString("\n<" + tag + ">\n")
	output.WriteString(strings.Join(shown, "\n"))
	if hidden := len(in) - len(shown); hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	output.WriteString("\n</" + tag + ">\n")
}
142
// sortDiagnostics sorts formatted diagnostics in place by severity, then
// alphabetically within each severity. The order is errors, warnings, info,
// hints, then any line with an unrecognised prefix. The same slice is returned
// for convenience.
//
// Severity order matters because writeDiagnostics shows only the first 10
// entries of a section. Sorting everything after the errors alphabetically
// ("Hint" < "Info" < "Warn") let hints and info lines push warnings out of the
// visible window.
func sortDiagnostics(in []string) []string {
	// Prefixes as produced by formatDiagnostic, in priority order.
	severityOrder := [...]string{"Error", "Warn", "Info", "Hint"}
	rank := func(diag string) int {
		for i, prefix := range severityOrder {
			if strings.HasPrefix(diag, prefix) {
				return i
			}
		}
		return len(severityOrder)
	}

	sort.Slice(in, func(i, j int) bool {
		if ri, rj := rank(in[i]), rank(in[j]); ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
154
155func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
156	severity := "Info"
157	switch diagnostic.Severity {
158	case protocol.SeverityError:
159		severity = "Error"
160	case protocol.SeverityWarning:
161		severity = "Warn"
162	case protocol.SeverityHint:
163		severity = "Hint"
164	}
165
166	location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
167
168	sourceInfo := ""
169	if diagnostic.Source != "" {
170		sourceInfo = diagnostic.Source
171	} else if source != "" {
172		sourceInfo = source
173	}
174
175	codeInfo := ""
176	if diagnostic.Code != nil {
177		codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
178	}
179
180	tagsInfo := ""
181	if len(diagnostic.Tags) > 0 {
182		tags := []string{}
183		for _, tag := range diagnostic.Tags {
184			switch tag {
185			case protocol.Unnecessary:
186				tags = append(tags, "unnecessary")
187			case protocol.Deprecated:
188				tags = append(tags, "deprecated")
189			}
190		}
191		if len(tags) > 0 {
192			tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
193		}
194	}
195
196	return fmt.Sprintf("%s: %s [%s]%s%s %s",
197		severity,
198		location,
199		sourceInfo,
200		codeInfo,
201		tagsInfo,
202		diagnostic.Message)
203}
204
// countSeverity returns how many formatted diagnostics begin with the given
// severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}