1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"sort"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/lsp"
 13	"github.com/charmbracelet/crush/internal/lsp/protocol"
 14)
 15
// DiagnosticsParams is the JSON input that the diagnostics tool's Run method
// decodes from the tool call.
type DiagnosticsParams struct {
	// FilePath is the file to report diagnostics for. When it is empty, Run
	// skips the open/wait step and only reports what the clients already hold.
	FilePath string `json:"file_path"`
}
// diagnosticsTool is the tool value returned by NewDiagnosticsTool; it reports
// the diagnostics held by a set of LSP clients.
type diagnosticsTool struct {
	// lspClients maps a client name to its LSP client. The name is used as the
	// fallback source label when a diagnostic does not carry its own source.
	lspClients map[string]*lsp.Client
}
 22
const (
	// DiagnosticsToolName is the name returned by the diagnostics tool's Name
	// method and reported in its Info.
	DiagnosticsToolName    = "diagnostics"
	// diagnosticsDescription is the usage text reported as the Description in
	// the tool's Info.
	diagnosticsDescription = `Get diagnostics for a file and/or project.
WHEN TO USE THIS TOOL:
- Use when you need to check for errors or warnings in your code
- Helpful for debugging and ensuring code quality
- Good for getting a quick overview of issues in a file or project
HOW TO USE:
- Provide a path to a file to get diagnostics for that file
- Leave the path empty to get diagnostics for the entire project
- Results are displayed in a structured format with severity levels
FEATURES:
- Displays errors, warnings, and hints
- Groups diagnostics by severity
- Provides detailed information about each diagnostic
LIMITATIONS:
- Results are limited to the diagnostics provided by the LSP clients
- May not cover all possible issues in the code
- Does not provide suggestions for fixing issues
TIPS:
- Use in conjunction with other tools for a comprehensive code review
- Combine with the LSP client for real-time diagnostics
`
)
 47
 48func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
 49	return &diagnosticsTool{
 50		lspClients,
 51	}
 52}
 53
// Name returns DiagnosticsToolName.
func (b *diagnosticsTool) Name() string {
	return DiagnosticsToolName
}
 57
 58func (b *diagnosticsTool) Info() ToolInfo {
 59	return ToolInfo{
 60		Name:        DiagnosticsToolName,
 61		Description: diagnosticsDescription,
 62		Parameters: map[string]any{
 63			"file_path": map[string]any{
 64				"type":        "string",
 65				"description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
 66			},
 67		},
 68		Required: []string{},
 69	}
 70}
 71
 72func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 73	var params DiagnosticsParams
 74	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 75		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 76	}
 77
 78	lsps := b.lspClients
 79
 80	if len(lsps) == 0 {
 81		return NewTextErrorResponse("no LSP clients available"), nil
 82	}
 83
 84	if params.FilePath != "" {
 85		notifyLspOpenFile(ctx, params.FilePath, lsps)
 86		waitForLspDiagnostics(ctx, params.FilePath, lsps)
 87	}
 88
 89	output := getDiagnostics(params.FilePath, lsps)
 90
 91	return NewTextResponse(output), nil
 92}
 93
 94func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
 95	for _, client := range lsps {
 96		err := client.OpenFile(ctx, filePath)
 97		if err != nil {
 98			continue
 99		}
100	}
101}
102
103func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
104	if len(lsps) == 0 {
105		return
106	}
107
108	diagChan := make(chan struct{}, 1)
109
110	for _, client := range lsps {
111		originalDiags := client.GetDiagnostics()
112
113		handler := func(params json.RawMessage) {
114			lsp.HandleDiagnostics(client, params)
115			var diagParams protocol.PublishDiagnosticsParams
116			if err := json.Unmarshal(params, &diagParams); err != nil {
117				return
118			}
119
120			path, err := diagParams.URI.Path()
121			if err != nil {
122				slog.Error("Failed to convert diagnostic URI to path", "uri", diagParams.URI, "error", err)
123				return
124			}
125
126			if path == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
127				select {
128				case diagChan <- struct{}{}:
129				default:
130				}
131			}
132		}
133
134		client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
135
136		if client.IsFileOpen(filePath) {
137			err := client.NotifyChange(ctx, filePath)
138			if err != nil {
139				continue
140			}
141		} else {
142			err := client.OpenFile(ctx, filePath)
143			if err != nil {
144				continue
145			}
146		}
147	}
148
149	select {
150	case <-diagChan:
151	case <-time.After(5 * time.Second):
152	case <-ctx.Done():
153	}
154}
155
156func hasDiagnosticsChanged(current, original map[protocol.DocumentURI][]protocol.Diagnostic) bool {
157	for uri, diags := range current {
158		origDiags, exists := original[uri]
159		if !exists || len(diags) != len(origDiags) {
160			return true
161		}
162	}
163	return false
164}
165
166func getDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
167	fileDiagnostics := []string{}
168	projectDiagnostics := []string{}
169
170	formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
171		severity := "Info"
172		switch diagnostic.Severity {
173		case protocol.SeverityError:
174			severity = "Error"
175		case protocol.SeverityWarning:
176			severity = "Warn"
177		case protocol.SeverityHint:
178			severity = "Hint"
179		}
180
181		location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
182
183		sourceInfo := ""
184		if diagnostic.Source != "" {
185			sourceInfo = diagnostic.Source
186		} else if source != "" {
187			sourceInfo = source
188		}
189
190		codeInfo := ""
191		if diagnostic.Code != nil {
192			codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
193		}
194
195		tagsInfo := ""
196		if len(diagnostic.Tags) > 0 {
197			tags := []string{}
198			for _, tag := range diagnostic.Tags {
199				switch tag {
200				case protocol.Unnecessary:
201					tags = append(tags, "unnecessary")
202				case protocol.Deprecated:
203					tags = append(tags, "deprecated")
204				}
205			}
206			if len(tags) > 0 {
207				tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
208			}
209		}
210
211		return fmt.Sprintf("%s: %s [%s]%s%s %s",
212			severity,
213			location,
214			sourceInfo,
215			codeInfo,
216			tagsInfo,
217			diagnostic.Message)
218	}
219
220	for lspName, client := range lsps {
221		diagnostics := client.GetDiagnostics()
222		if len(diagnostics) > 0 {
223			for location, diags := range diagnostics {
224				path, err := location.Path()
225				if err != nil {
226					slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
227					continue
228				}
229				isCurrentFile := path == filePath
230
231				for _, diag := range diags {
232					formattedDiag := formatDiagnostic(path, diag, lspName)
233
234					if isCurrentFile {
235						fileDiagnostics = append(fileDiagnostics, formattedDiag)
236					} else {
237						projectDiagnostics = append(projectDiagnostics, formattedDiag)
238					}
239				}
240			}
241		}
242	}
243
244	sort.Slice(fileDiagnostics, func(i, j int) bool {
245		iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
246		jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
247		if iIsError != jIsError {
248			return iIsError // Errors come first
249		}
250		return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
251	})
252
253	sort.Slice(projectDiagnostics, func(i, j int) bool {
254		iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
255		jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
256		if iIsError != jIsError {
257			return iIsError
258		}
259		return projectDiagnostics[i] < projectDiagnostics[j]
260	})
261
262	var output strings.Builder
263
264	if len(fileDiagnostics) > 0 {
265		output.WriteString("\n<file_diagnostics>\n")
266		if len(fileDiagnostics) > 10 {
267			output.WriteString(strings.Join(fileDiagnostics[:10], "\n"))
268			fmt.Fprintf(&output, "\n... and %d more diagnostics", len(fileDiagnostics)-10)
269		} else {
270			output.WriteString(strings.Join(fileDiagnostics, "\n"))
271		}
272		output.WriteString("\n</file_diagnostics>\n")
273	}
274
275	if len(projectDiagnostics) > 0 {
276		output.WriteString("\n<project_diagnostics>\n")
277		if len(projectDiagnostics) > 10 {
278			output.WriteString(strings.Join(projectDiagnostics[:10], "\n"))
279			fmt.Fprintf(&output, "\n... and %d more diagnostics", len(projectDiagnostics)-10)
280		} else {
281			output.WriteString(strings.Join(projectDiagnostics, "\n"))
282		}
283		output.WriteString("\n</project_diagnostics>\n")
284	}
285
286	if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
287		fileErrors := countSeverity(fileDiagnostics, "Error")
288		fileWarnings := countSeverity(fileDiagnostics, "Warn")
289		projectErrors := countSeverity(projectDiagnostics, "Error")
290		projectWarnings := countSeverity(projectDiagnostics, "Warn")
291
292		output.WriteString("\n<diagnostic_summary>\n")
293		fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
294		fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
295		output.WriteString("</diagnostic_summary>\n")
296	}
297
298	return output.String()
299}
300
// countSeverity returns how many entries of diagnostics begin with the given
// severity prefix (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	var matches int
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		matches++
	}
	return matches
}