1package tools
2
3import (
4 "cmp"
5 "context"
6 _ "embed"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "path/filepath"
12 "regexp"
13 "slices"
14 "sort"
15 "strings"
16
17 "charm.land/fantasy"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
20)
21
// ReferencesParams is the input accepted by the lsp_references tool.
//
// Symbol is required: the tool rejects an empty value. Path is optional;
// when empty the tool falls back to "." and uses it as the root of the text
// search that locates candidate occurrences of Symbol.
type ReferencesParams struct {
	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
	Path   string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
}
26
// referencesTool pairs the references tool with an LSP manager and supplies
// its Name.
//
// NOTE(review): NewReferencesTool builds the tool through
// fantasy.NewAgentTool and never constructs this type, and lspManager is not
// read by any code in this file section — confirm whether this type is still
// used elsewhere or is leftover.
type referencesTool struct {
	lspManager *lsp.Manager
}
30
// ReferencesToolName is the name the references tool is registered under; it
// is passed to fantasy.NewAgentTool and returned by referencesTool.Name.
const ReferencesToolName = "lsp_references"
32
// referencesDescription is the tool description handed to
// fantasy.NewAgentTool; its contents are embedded at build time from
// references.md.
//
//go:embed references.md
var referencesDescription []byte
35
36func NewReferencesTool(lspManager *lsp.Manager) fantasy.AgentTool {
37 return fantasy.NewAgentTool(
38 ReferencesToolName,
39 string(referencesDescription),
40 func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
41 if params.Symbol == "" {
42 return fantasy.NewTextErrorResponse("symbol is required"), nil
43 }
44
45 if lspManager.Clients().Len() == 0 {
46 return fantasy.NewTextErrorResponse("no LSP clients available"), nil
47 }
48
49 workingDir := cmp.Or(params.Path, ".")
50
51 matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
52 if err != nil {
53 return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
54 }
55
56 if len(matches) == 0 {
57 return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
58 }
59
60 var allLocations []protocol.Location
61 var allErrs error
62 for _, match := range matches {
63 locations, err := find(ctx, lspManager, params.Symbol, match)
64 if err != nil {
65 if strings.Contains(err.Error(), "no identifier found") {
66 // grep probably matched a comment, string value, or something else that's irrelevant
67 continue
68 }
69 slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
70 allErrs = errors.Join(allErrs, err)
71 continue
72 }
73 allLocations = append(allLocations, locations...)
74 // Once we have results, we're done - LSP returns all references
75 // for the symbol, not just from this file.
76 if len(locations) > 0 {
77 break
78 }
79 }
80
81 if len(allLocations) > 0 {
82 output := formatReferences(cleanupLocations(allLocations))
83 return fantasy.NewTextResponse(output), nil
84 }
85
86 if allErrs != nil {
87 return fantasy.NewTextErrorResponse(allErrs.Error()), nil
88 }
89 return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
90 })
91}
92
93func (r *referencesTool) Name() string {
94 return ReferencesToolName
95}
96
97func find(ctx context.Context, lspManager *lsp.Manager, symbol string, match grepMatch) ([]protocol.Location, error) {
98 absPath, err := filepath.Abs(match.path)
99 if err != nil {
100 return nil, fmt.Errorf("failed to get absolute path: %s", err)
101 }
102
103 var client *lsp.Client
104 for c := range lspManager.Clients().Seq() {
105 if c.HandlesFile(absPath) {
106 client = c
107 break
108 }
109 }
110
111 if client == nil {
112 slog.Warn("No LSP clients to handle", "path", match.path)
113 return nil, nil
114 }
115
116 return client.FindReferences(
117 ctx,
118 absPath,
119 match.lineNum,
120 match.charNum+getSymbolOffset(symbol),
121 true,
122 )
123}
124
// getSymbolOffset returns the byte offset, within symbol, of its final
// segment: the part after whichever qualifier separator ("::", "." or "\")
// ends last. For example, it returns the offset of "Bar" in "foo.Bar", of
// "method" in both "Class::method" and the Ruby-style "Module::Class.method",
// and of "User" in `\App\User`. It returns 0 when symbol has no separator.
//
// Separators are compared by where they end, not by a fixed priority, so a
// later "." is not shadowed by an earlier "::".
func getSymbolOffset(symbol string) int {
	offset := 0
	// "::" - Rust, C++, Ruby modules/classes, PHP static.
	// "."  - Go, Python, JavaScript, Java, C#, Ruby methods.
	// "\"  - PHP namespaces.
	for _, sep := range []string{"::", ".", `\`} {
		if idx := strings.LastIndex(symbol, sep); idx != -1 && idx+len(sep) > offset {
			offset = idx + len(sep)
		}
	}
	return offset
}
142
143func cleanupLocations(locations []protocol.Location) []protocol.Location {
144 slices.SortFunc(locations, func(a, b protocol.Location) int {
145 if a.URI != b.URI {
146 return strings.Compare(string(a.URI), string(b.URI))
147 }
148 if a.Range.Start.Line != b.Range.Start.Line {
149 return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
150 }
151 return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
152 })
153 return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
154 return a.URI == b.URI &&
155 a.Range.Start.Line == b.Range.Start.Line &&
156 a.Range.Start.Character == b.Range.Start.Character
157 })
158}
159
160func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
161 files := make(map[string][]protocol.Location)
162 for _, loc := range locations {
163 path, err := loc.URI.Path()
164 if err != nil {
165 slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
166 continue
167 }
168 files[path] = append(files[path], loc)
169 }
170 return files
171}
172
173func formatReferences(locations []protocol.Location) string {
174 fileRefs := groupByFilename(locations)
175 files := slices.Collect(maps.Keys(fileRefs))
176 sort.Strings(files)
177
178 var output strings.Builder
179 fmt.Fprintf(&output, "Found %d reference(s) in %d file(s):\n\n", len(locations), len(files))
180
181 for _, file := range files {
182 refs := fileRefs[file]
183 fmt.Fprintf(&output, "%s (%d reference(s)):\n", file, len(refs))
184 for _, ref := range refs {
185 line := ref.Range.Start.Line + 1
186 char := ref.Range.Start.Character + 1
187 fmt.Fprintf(&output, " Line %d, Column %d\n", line, char)
188 }
189 output.WriteString("\n")
190 }
191
192 return output.String()
193}