1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"sort"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/lsp"
 15	"github.com/charmbracelet/crush/internal/proto"
 16	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 17)
 18
// DiagnosticsParams is the input accepted by the diagnostics tool. It is an
// alias of proto.DiagnosticsParams; Run reads its FilePath field, where an
// empty value requests project-wide diagnostics.
type DiagnosticsParams = proto.DiagnosticsParams

// diagnosticsTool reports diagnostics collected from the configured LSP
// clients, either for a single file or for the whole project.
type diagnosticsTool struct {
	// lspClients holds the available LSP clients keyed by LSP name; the key
	// is used as the fallback source label when formatting diagnostics.
	lspClients *csync.Map[string, *lsp.Client]
}

// DiagnosticsToolName is the identifier returned by Name and Info.
const DiagnosticsToolName = "diagnostics"

// diagnosticsDescription is the tool description used by Info, embedded from
// diagnostics.md at build time.
//
//go:embed diagnostics.md
var diagnosticsDescription []byte
 29
 30func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
 31	return &diagnosticsTool{
 32		lspClients,
 33	}
 34}
 35
 36func (b *diagnosticsTool) Name() string {
 37	return DiagnosticsToolName
 38}
 39
 40func (b *diagnosticsTool) Info() ToolInfo {
 41	return ToolInfo{
 42		Name:        DiagnosticsToolName,
 43		Description: string(diagnosticsDescription),
 44		Parameters: map[string]any{
 45			"file_path": map[string]any{
 46				"type":        "string",
 47				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 48			},
 49		},
 50		Required: []string{},
 51	}
 52}
 53
 54func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 55	var params DiagnosticsParams
 56	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 57		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 58	}
 59
 60	if b.lspClients.Len() == 0 {
 61		return NewTextErrorResponse("no LSP clients available"), nil
 62	}
 63	notifyLSPs(ctx, b.lspClients, params.FilePath)
 64	output := getDiagnostics(params.FilePath, b.lspClients)
 65	return NewTextResponse(output), nil
 66}
 67
 68func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
 69	if filepath == "" {
 70		return
 71	}
 72	for client := range lsps.Seq() {
 73		if !client.HandlesFile(filepath) {
 74			continue
 75		}
 76		_ = client.OpenFileOnDemand(ctx, filepath)
 77		_ = client.NotifyChange(ctx, filepath)
 78		client.WaitForDiagnostics(ctx, 5*time.Second)
 79	}
 80}
 81
 82func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
 83	fileDiagnostics := []string{}
 84	projectDiagnostics := []string{}
 85
 86	for lspName, client := range lsps.Seq2() {
 87		for location, diags := range client.GetDiagnostics() {
 88			path, err := location.Path()
 89			if err != nil {
 90				slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
 91				continue
 92			}
 93			isCurrentFile := path == filePath
 94			for _, diag := range diags {
 95				formattedDiag := formatDiagnostic(path, diag, lspName)
 96				if isCurrentFile {
 97					fileDiagnostics = append(fileDiagnostics, formattedDiag)
 98				} else {
 99					projectDiagnostics = append(projectDiagnostics, formattedDiag)
100				}
101			}
102		}
103	}
104
105	sortDiagnostics(fileDiagnostics)
106	sortDiagnostics(projectDiagnostics)
107
108	var output strings.Builder
109	writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
110	writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
111
112	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
113		fileErrors := countSeverity(fileDiagnostics, "Error")
114		fileWarnings := countSeverity(fileDiagnostics, "Warn")
115		projectErrors := countSeverity(projectDiagnostics, "Error")
116		projectWarnings := countSeverity(projectDiagnostics, "Warn")
117		output.WriteString("\n<diagnostic_summary>\n")
118		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
119		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
120		output.WriteString("</diagnostic_summary>\n")
121	}
122
123	out := output.String()
124	slog.Info("Diagnostics", "output", fmt.Sprintf("%q", out))
125	return out
126}
127
// writeDiagnostics appends the entries of in to output, one per line, wrapped
// in <tag> ... </tag>. At most the first 10 entries are written; when more
// exist, a "... and N more diagnostics" line reports how many were left out.
// Nothing is written when in is empty.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	if len(in) == 0 {
		return
	}

	const limit = 10
	shown := in
	if len(shown) > limit {
		shown = shown[:limit]
	}

	output.WriteString("\n<" + tag + ">\n")
	output.WriteString(strings.Join(shown, "\n"))
	if hidden := len(in) - len(shown); hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}
	output.WriteString("\n</" + tag + ">\n")
}
141
// sortDiagnostics sorts formatted diagnostics in place, most severe first, and
// returns the same slice. Severity is read from the prefix formatDiagnostic
// writes: "Error", then "Warn", then "Info", then "Hint", then anything else.
// Entries of equal severity are ordered alphabetically.
//
// Ranking by severity, rather than only "errors first, then alphabetical",
// matters because writeDiagnostics keeps only the first entries. With a plain
// alphabetical order, "Hint" and "Info" lines sort before "Warn" lines and can
// push warnings out of the visible output.
func sortDiagnostics(in []string) []string {
	rank := func(diag string) int {
		switch {
		case strings.HasPrefix(diag, "Error"):
			return 0
		case strings.HasPrefix(diag, "Warn"):
			return 1
		case strings.HasPrefix(diag, "Info"):
			return 2
		case strings.HasPrefix(diag, "Hint"):
			return 3
		default:
			return 4
		}
	}
	sort.Slice(in, func(i, j int) bool {
		if ri, rj := rank(in[i]), rank(in[j]); ri != rj {
			return ri < rj
		}
		return in[i] < in[j]
	})
	return in
}
153
154func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
155	severity := "Info"
156	switch diagnostic.Severity {
157	case protocol.SeverityError:
158		severity = "Error"
159	case protocol.SeverityWarning:
160		severity = "Warn"
161	case protocol.SeverityHint:
162		severity = "Hint"
163	}
164
165	location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
166
167	sourceInfo := ""
168	if diagnostic.Source != "" {
169		sourceInfo = diagnostic.Source
170	} else if source != "" {
171		sourceInfo = source
172	}
173
174	codeInfo := ""
175	if diagnostic.Code != nil {
176		codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
177	}
178
179	tagsInfo := ""
180	if len(diagnostic.Tags) > 0 {
181		tags := []string{}
182		for _, tag := range diagnostic.Tags {
183			switch tag {
184			case protocol.Unnecessary:
185				tags = append(tags, "unnecessary")
186			case protocol.Deprecated:
187				tags = append(tags, "deprecated")
188			}
189		}
190		if len(tags) > 0 {
191			tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
192		}
193	}
194
195	return fmt.Sprintf("%s: %s [%s]%s%s %s",
196		severity,
197		location,
198		sourceInfo,
199		codeInfo,
200		tagsInfo,
201		diagnostic.Message)
202}
203
// countSeverity reports how many entries in diagnostics begin with the given
// severity prefix (intended to be one of the labels formatDiagnostic writes,
// such as "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	n := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		n++
	}
	return n
}