1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"maps"
  9	"sort"
 10	"strings"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/lsp"
 14	"github.com/charmbracelet/crush/internal/lsp/protocol"
 15)
 16
 17type DiagnosticsParams struct {
 18	FilePath string `json:"file_path"`
 19}
 20type diagnosticsTool struct {
 21	lspClients map[string]*lsp.Client
 22}
 23
 24const (
 25	DiagnosticsToolName    = "diagnostics"
 26	diagnosticsDescription = `Get diagnostics for a file and/or project.
 27WHEN TO USE THIS TOOL:
 28- Use when you need to check for errors or warnings in your code
 29- Helpful for debugging and ensuring code quality
 30- Good for getting a quick overview of issues in a file or project
 31HOW TO USE:
 32- Provide a path to a file to get diagnostics for that file
 33- Leave the path empty to get diagnostics for the entire project
 34- Results are displayed in a structured format with severity levels
 35FEATURES:
 36- Displays errors, warnings, and hints
 37- Groups diagnostics by severity
 38- Provides detailed information about each diagnostic
 39LIMITATIONS:
 40- Results are limited to the diagnostics provided by the LSP clients
 41- May not cover all possible issues in the code
 42- Does not provide suggestions for fixing issues
 43TIPS:
 44- Use in conjunction with other tools for a comprehensive code review
 45- Combine with the LSP client for real-time diagnostics
 46`
 47)
 48
 49func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
 50	return &diagnosticsTool{
 51		lspClients,
 52	}
 53}
 54
 55func (b *diagnosticsTool) Name() string {
 56	return DiagnosticsToolName
 57}
 58
 59func (b *diagnosticsTool) Info() ToolInfo {
 60	return ToolInfo{
 61		Name:        DiagnosticsToolName,
 62		Description: diagnosticsDescription,
 63		Parameters: map[string]any{
 64			"file_path": map[string]any{
 65				"type":        "string",
 66				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 67			},
 68		},
 69		Required: []string{},
 70	}
 71}
 72
 73func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 74	var params DiagnosticsParams
 75	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 76		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 77	}
 78
 79	lsps := b.lspClients
 80
 81	if len(lsps) == 0 {
 82		return NewTextErrorResponse("no LSP clients available"), nil
 83	}
 84
 85	if params.FilePath != "" {
 86		notifyLspOpenFile(ctx, params.FilePath, lsps)
 87		waitForLspDiagnostics(ctx, params.FilePath, lsps)
 88	}
 89
 90	output := getDiagnostics(params.FilePath, lsps)
 91
 92	return NewTextResponse(output), nil
 93}
 94
 95func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 96	for _, client := range lsps {
 97		err := client.OpenFile(ctx, filePath)
 98		if err != nil {
 99			continue
100		}
101	}
102}
103
// waitForLspDiagnostics nudges every LSP client to re-examine filePath.
// It then blocks until the first of the following happens:
//   - a client publishes diagnostics whose URI path equals filePath;
//   - a client publishes diagnostics after which its stored diagnostics differ
//     from the snapshot taken here (as judged by hasDiagnosticsChanged);
//   - 5 seconds elapse;
//   - ctx is cancelled.
//
// For each client, it snapshots the current diagnostics and installs a
// textDocument/publishDiagnostics handler. It then sends a change notification
// if the file is already open, or opens it otherwise. Errors from those
// notifications are ignored, and the wait still runs.
//
// NOTE(review): the handler is never unregistered here, so it outlives this
// call and keeps references to diagChan and the snapshot. Its non-blocking send
// keeps that from blocking. It presumably replaces the client's default
// handler for this method, which is why it calls lsp.HandleDiagnostics itself.
// TODO confirm RegisterNotificationHandler semantics.
func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
	if len(lsps) == 0 {
		return
	}

	// Capacity 1 keeps the first signal even if it arrives before the final
	// select. Later signals are dropped by the non-blocking send in the handler.
	diagChan := make(chan struct{}, 1)

	for _, client := range lsps {
		// Shallow copy of this client's diagnostics before the nudge, used as
		// the baseline for change detection.
		originalDiags := make(map[protocol.DocumentURI][]protocol.Diagnostic)
		maps.Copy(originalDiags, client.GetDiagnostics())

		// NOTE(review): handler captures the loop variable `client`. It is
		// per-iteration only under Go >= 1.22 loop semantics; confirm the
		// module's go version.
		handler := func(params json.RawMessage) {
			// Run the standard diagnostics handling first, so the client's
			// stored diagnostics are updated before they are compared below.
			lsp.HandleDiagnostics(client, params)
			var diagParams protocol.PublishDiagnosticsParams
			if err := json.Unmarshal(params, &diagParams); err != nil {
				return
			}

			path, err := diagParams.URI.Path()
			if err != nil {
				slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
				return
			}

			if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
				// Non-blocking: if a signal is already pending, drop this one.
				select {
				case diagChan <- struct{}{}:
				default:
				}
			}
		}

		// Registered before the notification below, presumably so that a
		// quick reply from the server is not missed.
		client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)

		// Errors are ignored. Each `continue` is the last statement of the
		// loop body, so it has no effect beyond skipping nothing.
		if client.IsFileOpen(filePath) {
			err := client.NotifyChange(ctx, filePath)
			if err != nil {
				continue
			}
		} else {
			err := client.OpenFile(ctx, filePath)
			if err != nil {
				continue
			}
		}
	}

	// Wait for the first relevant publish, a fixed 5s timeout, or cancellation.
	select {
	case <-diagChan:
	case <-time.After(5 * time.Second):
	case <-ctx.Done():
	}
}
157
158func hasDiagnosticsChanged(current, original map[protocol.DocumentURI][]protocol.Diagnostic) bool {
159	for uri, diags := range current {
160		origDiags, exists := original[uri]
161		if !exists || len(diags) != len(origDiags) {
162			return true
163		}
164	}
165	return false
166}
167
168func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
169	fileDiagnostics := []string{}
170	projectDiagnostics := []string{}
171
172	formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
173		severity := "Info"
174		switch diagnostic.Severity {
175		case protocol.SeverityError:
176			severity = "Error"
177		case protocol.SeverityWarning:
178			severity = "Warn"
179		case protocol.SeverityHint:
180			severity = "Hint"
181		}
182
183		location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
184
185		sourceInfo := ""
186		if diagnostic.Source != "" {
187			sourceInfo = diagnostic.Source
188		} else if source != "" {
189			sourceInfo = source
190		}
191
192		codeInfo := ""
193		if diagnostic.Code != nil {
194			codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
195		}
196
197		tagsInfo := ""
198		if len(diagnostic.Tags) > 0 {
199			tags := []string{}
200			for _, tag := range diagnostic.Tags {
201				switch tag {
202				case protocol.Unnecessary:
203					tags = append(tags, "unnecessary")
204				case protocol.Deprecated:
205					tags = append(tags, "deprecated")
206				}
207			}
208			if len(tags) > 0 {
209				tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
210			}
211		}
212
213		return fmt.Sprintf("%s: %s [%s]%s%s %s",
214			severity,
215			location,
216			sourceInfo,
217			codeInfo,
218			tagsInfo,
219			diagnostic.Message)
220	}
221
222	for lspName, client := range lsps {
223		diagnostics := client.GetDiagnostics()
224		if len(diagnostics) > 0 {
225			for location, diags := range diagnostics {
226				path, err := location.Path()
227				if err != nil {
228					slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
229					continue
230				}
231				isCurrentFile := path == filePath
232
233				for _, diag := range diags {
234					formattedDiag := formatDiagnostic(path, diag, lspName)
235
236					if isCurrentFile {
237						fileDiagnostics = append(fileDiagnostics, formattedDiag)
238					} else {
239						projectDiagnostics = append(projectDiagnostics, formattedDiag)
240					}
241				}
242			}
243		}
244	}
245
246	sort.Slice(fileDiagnostics, func(i, j int) bool {
247		iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
248		jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
249		if iIsError != jIsError {
250			return iIsError // Errors come first
251		}
252		return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
253	})
254
255	sort.Slice(projectDiagnostics, func(i, j int) bool {
256		iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
257		jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
258		if iIsError != jIsError {
259			return iIsError
260		}
261		return projectDiagnostics[i] < projectDiagnostics[j]
262	})
263
264	var output strings.Builder
265
266	if len(fileDiagnostics) > 0 {
267		output.WriteString("\n<file_diagnostics>\n")
268		if len(fileDiagnostics) > 10 {
269			output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
270			fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
271		} else {
272			output.WriteString(strings.Join(fileDiagnostics, "\n"))
273		}
274		output.WriteString("\n</file_diagnostics>\n")
275	}
276
277	if len(projectDiagnostics) > 0 {
278		output.WriteString("\n<project_diagnostics>\n")
279		if len(projectDiagnostics) > 10 {
280			output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
281			fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
282		} else {
283			output.WriteString(strings.Join(projectDiagnostics, "\n"))
284		}
285		output.WriteString("\n</project_diagnostics>\n")
286	}
287
288	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
289		fileErrors := countSeverity(fileDiagnostics, "Error")
290		fileWarnings := countSeverity(fileDiagnostics, "Warn")
291		projectErrors := countSeverity(projectDiagnostics, "Error")
292		projectWarnings := countSeverity(projectDiagnostics, "Warn")
293
294		output.WriteString("\n<diagnostic_summary>\n")
295		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
296		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
297		output.WriteString("</diagnostic_summary>\n")
298	}
299
300	return output.String()
301}
302
// countSeverity returns how many formatted diagnostic lines begin with the
// given severity label, such as "Error" or "Warn".
func countSeverity(diagnostics []string, severity string) int {
	total := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		total++
	}
	return total
}