1package tools
2
3import (
4 "cmp"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "log/slog"
11 "maps"
12 "path/filepath"
13 "regexp"
14 "slices"
15 "sort"
16 "strings"
17
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/lsp"
20 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
21)
22
// ReferencesParams is the decoded JSON input of the lsp_references tool.
type ReferencesParams struct {
	// Symbol is the identifier to look up, optionally qualified
	// (e.g. "pkg.Name", "Class::method"). Run rejects an empty value.
	Symbol string `json:"symbol"`

	// Path is the directory to search; Run substitutes "." when it is empty.
	Path string `json:"path"`
}
27
28type referencesTool struct {
29 lspClients *csync.Map[string, *lsp.Client]
30}
31
// ReferencesToolName is this tool's identifier, as reported by Name and Info.
const ReferencesToolName = "lsp_references"

// referencesDescription is the tool description reported by Info, embedded at
// build time from references.md in this package's directory.
//
//go:embed references.md
var referencesDescription []byte
36
37func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
38 return &referencesTool{
39 lspClients,
40 }
41}
42
43func (r *referencesTool) Name() string {
44 return ReferencesToolName
45}
46
47func (r *referencesTool) Info() ToolInfo {
48 return ToolInfo{
49 Name: ReferencesToolName,
50 Description: string(referencesDescription),
51 Parameters: map[string]any{
52 "symbol": map[string]any{
53 "type": "string",
54 "description": "The symbol name to search for (e.g., function name, variable name, type name).",
55 },
56 "path": map[string]any{
57 "type": "string",
58 "description": "The directory to search in. Should be the entire project most of the time. Defaults to the current working directory.",
59 },
60 },
61 Required: []string{"symbol"},
62 }
63}
64
65func (r *referencesTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
66 var params ReferencesParams
67 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
68 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
69 }
70
71 if params.Symbol == "" {
72 return NewTextErrorResponse("symbol is required"), nil
73 }
74
75 if r.lspClients.Len() == 0 {
76 return NewTextErrorResponse("no LSP clients available"), nil
77 }
78
79 workingDir := cmp.Or(params.Path, ".")
80
81 matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
82 if err != nil {
83 return NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
84 }
85
86 if len(matches) == 0 {
87 return NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
88 }
89
90 var allLocations []protocol.Location
91 var allErrs error
92 for _, match := range matches {
93 locations, err := r.find(ctx, params.Symbol, match)
94 if err != nil {
95 if strings.Contains(err.Error(), "no identifier found") {
96 // grep probably matched a comment, string value, or something else that's irrelevant
97 continue
98 }
99 slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
100 allErrs = errors.Join(allErrs, err)
101 continue
102 }
103 allLocations = append(allLocations, locations...)
104 // XXX: should we break here or look for all results?
105 }
106
107 if len(allLocations) > 0 {
108 output := formatReferences(cleanupLocations(allLocations))
109 return NewTextResponse(output), nil
110 }
111
112 if allErrs != nil {
113 return NewTextErrorResponse(allErrs.Error()), nil
114 }
115 return NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
116}
117
118func (r *referencesTool) find(ctx context.Context, symbol string, match grepMatch) ([]protocol.Location, error) {
119 absPath, err := filepath.Abs(match.path)
120 if err != nil {
121 return nil, fmt.Errorf("failed to get absolute path: %s", err)
122 }
123
124 var client *lsp.Client
125 for c := range r.lspClients.Seq() {
126 if c.HandlesFile(absPath) {
127 client = c
128 break
129 }
130 }
131
132 if client == nil {
133 slog.Warn("No LSP clients to handle", "path", match.path)
134 return nil, nil
135 }
136
137 return client.FindReferences(
138 ctx,
139 absPath,
140 match.lineNum,
141 match.charNum+getSymbolOffset(symbol),
142 true,
143 )
144}
145
// getSymbolOffset returns the byte offset (equal to the character offset for
// ASCII symbols) of the final name component in a qualified symbol, e.g. "Bar"
// in "foo.Bar", "method" in "Class::method", or "User" in `App\User`.
//
// The recognized separators are "::" (Rust, C++, Ruby modules/classes, PHP
// static), "." (Go, Python, JavaScript, Java, C#, Ruby methods) and `\` (PHP
// namespaces). The rightmost separator of any kind wins, so mixed forms such
// as "Foo::Bar.baz" resolve to "baz". Unqualified symbols yield 0.
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range []string{"::", ".", "\\"} {
		if idx := strings.LastIndex(symbol, sep); idx != -1 {
			offset = max(offset, idx+len(sep))
		}
	}
	return offset
}
163
164func cleanupLocations(locations []protocol.Location) []protocol.Location {
165 slices.SortFunc(locations, func(a, b protocol.Location) int {
166 if a.URI != b.URI {
167 return strings.Compare(string(a.URI), string(b.URI))
168 }
169 if a.Range.Start.Line != b.Range.Start.Line {
170 return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
171 }
172 return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
173 })
174 return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
175 return a.URI == b.URI &&
176 a.Range.Start.Line == b.Range.Start.Line &&
177 a.Range.Start.Character == b.Range.Start.Character
178 })
179}
180
181func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
182 files := make(map[string][]protocol.Location)
183 for _, loc := range locations {
184 path, err := loc.URI.Path()
185 if err != nil {
186 slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
187 continue
188 }
189 files[path] = append(files[path], loc)
190 }
191 return files
192}
193
194func formatReferences(locations []protocol.Location) string {
195 fileRefs := groupByFilename(locations)
196 files := slices.Collect(maps.Keys(fileRefs))
197 sort.Strings(files)
198
199 var output strings.Builder
200 output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
201
202 for _, file := range files {
203 refs := fileRefs[file]
204 output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
205 for _, ref := range refs {
206 line := ref.Range.Start.Line + 1
207 char := ref.Range.Start.Character + 1
208 output.WriteString(fmt.Sprintf(" Line %d, Column %d\n", line, char))
209 }
210 output.WriteString("\n")
211 }
212
213 return output.String()
214}