1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"log/slog"
  8	"sort"
  9	"strings"
 10	"time"
 11
 12	"charm.land/fantasy"
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/lsp"
 15	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 16)
 17
 18type DiagnosticsParams struct {
 19	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave w empty for project diagnostics)"`
 20}
 21
// DiagnosticsToolName is the tool name passed to fantasy.NewAgentTool when
// the diagnostics tool is constructed.
const DiagnosticsToolName = "lsp_diagnostics"

// diagnosticsDescription is the tool description passed to
// fantasy.NewAgentTool, embedded at build time from diagnostics.md.
// The go:embed directive below must stay immediately above this var.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
 26
 27func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
 28	return fantasy.NewAgentTool(
 29		DiagnosticsToolName,
 30		string(diagnosticsDescription),
 31		func(ctx context.Context, params DiagnosticsParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 32			if lspClients.Len() == 0 {
 33				return fantasy.NewTextErrorResponse("no LSP clients available"), nil
 34			}
 35			notifyLSPs(ctx, lspClients, params.FilePath)
 36			output := getDiagnostics(params.FilePath, lspClients)
 37			return fantasy.NewTextResponse(output), nil
 38		})
 39}
 40
 41func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
 42	if filepath == "" {
 43		return
 44	}
 45	for client := range lsps.Seq() {
 46		if !client.HandlesFile(filepath) {
 47			continue
 48		}
 49		_ = client.OpenFileOnDemand(ctx, filepath)
 50		_ = client.NotifyChange(ctx, filepath)
 51		client.WaitForDiagnostics(ctx, 5*time.Second)
 52	}
 53}
 54
 55func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
 56	fileDiagnostics := []string{}
 57	projectDiagnostics := []string{}
 58
 59	for lspName, client := range lsps.Seq2() {
 60		for location, diags := range client.GetDiagnostics() {
 61			path, err := location.Path()
 62			if err != nil {
 63				slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
 64				continue
 65			}
 66			isCurrentFile := path == filePath
 67			for _, diag := range diags {
 68				formattedDiag := formatDiagnostic(path, diag, lspName)
 69				if isCurrentFile {
 70					fileDiagnostics = append(fileDiagnostics, formattedDiag)
 71				} else {
 72					projectDiagnostics = append(projectDiagnostics, formattedDiag)
 73				}
 74			}
 75		}
 76	}
 77
 78	sortDiagnostics(fileDiagnostics)
 79	sortDiagnostics(projectDiagnostics)
 80
 81	var output strings.Builder
 82	writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
 83	writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
 84
 85	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
 86		fileErrors := countSeverity(fileDiagnostics, "Error")
 87		fileWarnings := countSeverity(fileDiagnostics, "Warn")
 88		projectErrors := countSeverity(projectDiagnostics, "Error")
 89		projectWarnings := countSeverity(projectDiagnostics, "Warn")
 90		output.WriteString("\n<diagnostic_summary>\n")
 91		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
 92		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
 93		output.WriteString("</diagnostic_summary>\n")
 94	}
 95
 96	out := output.String()
 97	slog.Info("Diagnostics", "output", out)
 98	return out
 99}
100
// writeDiagnostics appends in to output as a section wrapped in <tag>...</tag>,
// one entry per line. At most ten entries are listed; if more exist, a final
// "... and N more diagnostics" line reports how many were omitted. Nothing is
// written when in is empty.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	const limit = 10

	if len(in) == 0 {
		return
	}

	shown, hidden := in, 0
	if len(in) > limit {
		shown, hidden = in[:limit], len(in)-limit
	}

	fmt.Fprintf(output, "\n<%s>\n%s", tag, strings.Join(shown, "\n"))
	if hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	fmt.Fprintf(output, "\n</%s>\n", tag)
}
114
// sortDiagnostics sorts formatted diagnostics in place by severity, most
// severe first (Error, Warn, Info, Hint, then anything else), and
// alphabetically within the same severity. It returns in for convenience.
//
// Severity is read from the entry's prefix, the leading severity label that
// formatDiagnostic writes. An explicit rank replaces the earlier "errors first,
// then alphabetical" order. That order sorted Hint and Info ahead of Warn, so
// the 10-entry cap in writeDiagnostics could hide warnings behind hints.
func sortDiagnostics(in []string) []string {
	order := [...]string{"Error", "Warn", "Info", "Hint"}
	rank := func(s string) int {
		for i, sev := range order {
			if strings.HasPrefix(s, sev) {
				return i
			}
		}
		return len(order)
	}
	sort.Slice(in, func(i, j int) bool {
		ri, rj := rank(in[i]), rank(in[j])
		if ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
126
127func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
128	severity := "Info"
129	switch diagnostic.Severity {
130	case protocol.SeverityError:
131		severity = "Error"
132	case protocol.SeverityWarning:
133		severity = "Warn"
134	case protocol.SeverityHint:
135		severity = "Hint"
136	}
137
138	location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
139
140	sourceInfo := ""
141	if diagnostic.Source != "" {
142		sourceInfo = diagnostic.Source
143	} else if source != "" {
144		sourceInfo = source
145	}
146
147	codeInfo := ""
148	if diagnostic.Code != nil {
149		codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
150	}
151
152	tagsInfo := ""
153	if len(diagnostic.Tags) > 0 {
154		tags := []string{}
155		for _, tag := range diagnostic.Tags {
156			switch tag {
157			case protocol.Unnecessary:
158				tags = append(tags, "unnecessary")
159			case protocol.Deprecated:
160				tags = append(tags, "deprecated")
161			}
162		}
163		if len(tags) > 0 {
164			tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
165		}
166	}
167
168	return fmt.Sprintf("%s: %s [%s]%s%s %s",
169		severity,
170		location,
171		sourceInfo,
172		codeInfo,
173		tagsInfo,
174		diagnostic.Message)
175}
176
// countSeverity reports how many formatted diagnostics start with the given
// severity label (e.g. "Error" or "Warn"). It relies on the severity being the
// first token of each entry.
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}