diagnostics.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"sort"
  9	"strings"
 10	"time"
 11
 12	"charm.land/fantasy"
 13	"github.com/charmbracelet/crush/internal/lsp"
 14	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 15)
 16
// DiagnosticsParams is the input accepted by the lsp_diagnostics tool.
type DiagnosticsParams struct {
	// FilePath is the file whose diagnostics are reported separately from the
	// rest of the project. When empty, no file is refreshed and every
	// diagnostic is reported as a project diagnostic.
	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave empty for project diagnostics)"`
}
 20
// DiagnosticsToolName is the tool name passed to fantasy.NewAgentTool by
// NewDiagnosticsTool.
const DiagnosticsToolName = "lsp_diagnostics"

// diagnosticsDescription is the tool description passed to
// fantasy.NewAgentTool. It is embedded at build time from diagnostics.md.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
 25
 26func NewDiagnosticsTool(lspManager *lsp.Manager) fantasy.AgentTool {
 27	return fantasy.NewAgentTool(
 28		DiagnosticsToolName,
 29		string(diagnosticsDescription),
 30		func(ctx context.Context, params DiagnosticsParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 31			if lspManager.Clients().Len() == 0 {
 32				return fantasy.NewTextErrorResponse("no LSP clients available"), nil
 33			}
 34			notifyLSPs(ctx, lspManager, params.FilePath)
 35			output := getDiagnostics(params.FilePath, lspManager)
 36			return fantasy.NewTextResponse(output), nil
 37		})
 38}
 39
 40func notifyLSPs(
 41	ctx context.Context,
 42	manager *lsp.Manager,
 43	filepath string,
 44) {
 45	if filepath == "" {
 46		return
 47	}
 48
 49	if manager == nil {
 50		return
 51	}
 52
 53	manager.Start(ctx, filepath)
 54
 55	for client := range manager.Clients().Seq() {
 56		if !client.HandlesFile(filepath) {
 57			continue
 58		}
 59		_ = client.OpenFileOnDemand(ctx, filepath)
 60		_ = client.NotifyChange(ctx, filepath)
 61		client.WaitForDiagnostics(ctx, 5*time.Second)
 62	}
 63}
 64
 65func getDiagnostics(filePath string, manager *lsp.Manager) string {
 66	if manager == nil {
 67		return ""
 68	}
 69
 70	var fileDiagnostics []string
 71	var projectDiagnostics []string
 72
 73	for lspName, client := range manager.Clients().Seq2() {
 74		for location, diags := range client.GetDiagnostics() {
 75			path, err := location.Path()
 76			if err != nil {
 77				slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
 78				continue
 79			}
 80			isCurrentFile := path == filePath
 81			for _, diag := range diags {
 82				formattedDiag := formatDiagnostic(path, diag, lspName)
 83				if isCurrentFile {
 84					fileDiagnostics = append(fileDiagnostics, formattedDiag)
 85				} else {
 86					projectDiagnostics = append(projectDiagnostics, formattedDiag)
 87				}
 88			}
 89		}
 90	}
 91
 92	sortDiagnostics(fileDiagnostics)
 93	sortDiagnostics(projectDiagnostics)
 94
 95	var output strings.Builder
 96	writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
 97	writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
 98
 99	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
100		fileErrors := countSeverity(fileDiagnostics, "Error")
101		fileWarnings := countSeverity(fileDiagnostics, "Warn")
102		projectErrors := countSeverity(projectDiagnostics, "Error")
103		projectWarnings := countSeverity(projectDiagnostics, "Warn")
104		output.WriteString("\n<diagnostic_summary>\n")
105		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
106		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
107		output.WriteString("</diagnostic_summary>\n")
108	}
109
110	out := output.String()
111	slog.Debug("Diagnostics", "output", out)
112	return out
113}
114
// writeDiagnostics appends in to output wrapped in "<tag>" / "</tag>" lines,
// one diagnostic per line.
//
// Only the first 10 entries are written; any remainder is summarized as
// "... and N more diagnostics". Nothing is written when in is empty.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	const limit = 10
	if len(in) == 0 {
		return
	}

	shown, hidden := in, 0
	if len(in) > limit {
		shown, hidden = in[:limit], len(in)-limit
	}

	output.WriteString("\n<" + tag + ">\n")
	output.WriteString(strings.Join(shown, "\n"))
	if hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	output.WriteString("\n</" + tag + ">\n")
}
128
// sortDiagnostics sorts formatted diagnostics in place and returns the same
// slice for convenience.
//
// Entries are ordered by the severity label formatDiagnostic puts at the
// start of each entry (Error, Warn, Info, Hint, then anything else). Entries
// with the same severity are ordered alphabetically.
//
// The full severity rank matters because writeDiagnostics shows only the
// first 10 entries. The previous "errors first, then alphabetical" order
// placed Hint and Info ahead of Warn, so warnings were the first thing
// truncated.
func sortDiagnostics(in []string) []string {
	rank := func(diag string) int {
		switch {
		case strings.HasPrefix(diag, "Error"):
			return 0
		case strings.HasPrefix(diag, "Warn"):
			return 1
		case strings.HasPrefix(diag, "Info"):
			return 2
		case strings.HasPrefix(diag, "Hint"):
			return 3
		default:
			return 4
		}
	}
	sort.Slice(in, func(i, j int) bool {
		ri, rj := rank(in[i]), rank(in[j])
		if ri != rj {
			return ri < rj // More severe first.
		}
		return in[i] < in[j] // Then alphabetically.
	})
	return in
}
140
141func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
142	severity := "Info"
143	switch diagnostic.Severity {
144	case protocol.SeverityError:
145		severity = "Error"
146	case protocol.SeverityWarning:
147		severity = "Warn"
148	case protocol.SeverityHint:
149		severity = "Hint"
150	}
151
152	location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
153
154	sourceInfo := source
155	if diagnostic.Source != "" {
156		sourceInfo += " " + diagnostic.Source
157	}
158
159	codeInfo := ""
160	if diagnostic.Code != nil {
161		codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
162	}
163
164	tagsInfo := ""
165	if len(diagnostic.Tags) > 0 {
166		var tags []string
167		for _, tag := range diagnostic.Tags {
168			switch tag {
169			case protocol.Unnecessary:
170				tags = append(tags, "unnecessary")
171			case protocol.Deprecated:
172				tags = append(tags, "deprecated")
173			}
174		}
175		if len(tags) > 0 {
176			tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
177		}
178	}
179
180	return fmt.Sprintf("%s: %s [%s]%s%s %s",
181		severity,
182		location,
183		sourceInfo,
184		codeInfo,
185		tagsInfo,
186		diagnostic.Message)
187}
188
// countSeverity reports how many formatted diagnostics begin with the given
// severity label, such as "Error" or "Warn" as written by formatDiagnostic.
func countSeverity(diagnostics []string, severity string) int {
	var n int
	for i := range diagnostics {
		if strings.HasPrefix(diagnostics[i], severity) {
			n++
		}
	}
	return n
}