1package tools
2
3import (
4 "cmp"
5 "context"
6 _ "embed"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "path/filepath"
12 "regexp"
13 "slices"
14 "sort"
15 "strings"
16
17 "charm.land/fantasy"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/lsp"
20 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
21)
22
// ReferencesParams is the input accepted by the references tool.
type ReferencesParams struct {
	// Symbol is the name to look up; the tool handler rejects an empty value.
	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
	// Path optionally narrows the search; the tool handler falls back to "."
	// when it is empty.
	Path string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
}
27
28type referencesTool struct {
29 lspClients *csync.Map[string, *lsp.Client]
30}
31
// ReferencesToolName is the name passed to fantasy.NewAgentTool for the
// references tool and returned by referencesTool.Name.
const ReferencesToolName = "lsp_references"
33
34//go:embed references.md
35var referencesDescription []byte
36
37func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
38 return fantasy.NewAgentTool(
39 ReferencesToolName,
40 string(referencesDescription),
41 func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
42 if params.Symbol == "" {
43 return fantasy.NewTextErrorResponse("symbol is required"), nil
44 }
45
46 if lspClients.Len() == 0 {
47 return fantasy.NewTextErrorResponse("no LSP clients available"), nil
48 }
49
50 workingDir := cmp.Or(params.Path, ".")
51
52 matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
53 if err != nil {
54 return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
55 }
56
57 if len(matches) == 0 {
58 return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
59 }
60
61 var allLocations []protocol.Location
62 var allErrs error
63 for _, match := range matches {
64 locations, err := find(ctx, lspClients, params.Symbol, match)
65 if err != nil {
66 if strings.Contains(err.Error(), "no identifier found") {
67 // grep probably matched a comment, string value, or something else that's irrelevant
68 continue
69 }
70 slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
71 allErrs = errors.Join(allErrs, err)
72 continue
73 }
74 allLocations = append(allLocations, locations...)
75 // XXX: should we break here or look for all results?
76 }
77
78 if len(allLocations) > 0 {
79 output := formatReferences(cleanupLocations(allLocations))
80 return fantasy.NewTextResponse(output), nil
81 }
82
83 if allErrs != nil {
84 return fantasy.NewTextErrorResponse(allErrs.Error()), nil
85 }
86 return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
87 })
88}
89
90func (r *referencesTool) Name() string {
91 return ReferencesToolName
92}
93
94func find(ctx context.Context, lspClients *csync.Map[string, *lsp.Client], symbol string, match grepMatch) ([]protocol.Location, error) {
95 absPath, err := filepath.Abs(match.path)
96 if err != nil {
97 return nil, fmt.Errorf("failed to get absolute path: %s", err)
98 }
99
100 var client *lsp.Client
101 for c := range lspClients.Seq() {
102 if c.HandlesFile(absPath) {
103 client = c
104 break
105 }
106 }
107
108 if client == nil {
109 slog.Warn("No LSP clients to handle", "path", match.path)
110 return nil, nil
111 }
112
113 return client.FindReferences(
114 ctx,
115 absPath,
116 match.lineNum,
117 match.charNum+getSymbolOffset(symbol),
118 true,
119 )
120}
121
// getSymbolOffset returns the byte offset of the actual symbol name inside a
// possibly qualified symbol (e.g., "Bar" in "foo.Bar" or "method" in
// "Class::method"). It returns 0 when the symbol contains no separator.
//
// The recognized separators are:
//   - "::" (Rust, C++, Ruby modules/classes, PHP static)
//   - "."  (Go, Python, JavaScript, Java, C#, Ruby methods)
//   - "\"  (PHP namespaces)
//
// The right-most separator of any kind wins, so mixed forms such as
// "ns::obj.method" resolve to "method" rather than to "obj.method".
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range [...]string{"::", ".", "\\"} {
		idx := strings.LastIndex(symbol, sep)
		if idx == -1 {
			continue
		}
		// Keep the offset just past whichever separator ends furthest right.
		if end := idx + len(sep); end > offset {
			offset = end
		}
	}
	return offset
}
139
140func cleanupLocations(locations []protocol.Location) []protocol.Location {
141 slices.SortFunc(locations, func(a, b protocol.Location) int {
142 if a.URI != b.URI {
143 return strings.Compare(string(a.URI), string(b.URI))
144 }
145 if a.Range.Start.Line != b.Range.Start.Line {
146 return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
147 }
148 return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
149 })
150 return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
151 return a.URI == b.URI &&
152 a.Range.Start.Line == b.Range.Start.Line &&
153 a.Range.Start.Character == b.Range.Start.Character
154 })
155}
156
157func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
158 files := make(map[string][]protocol.Location)
159 for _, loc := range locations {
160 path, err := loc.URI.Path()
161 if err != nil {
162 slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
163 continue
164 }
165 files[path] = append(files[path], loc)
166 }
167 return files
168}
169
170func formatReferences(locations []protocol.Location) string {
171 fileRefs := groupByFilename(locations)
172 files := slices.Collect(maps.Keys(fileRefs))
173 sort.Strings(files)
174
175 var output strings.Builder
176 output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
177
178 for _, file := range files {
179 refs := fileRefs[file]
180 output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
181 for _, ref := range refs {
182 line := ref.Range.Start.Line + 1
183 char := ref.Range.Start.Character + 1
184 output.WriteString(fmt.Sprintf(" Line %d, Column %d\n", line, char))
185 }
186 output.WriteString("\n")
187 }
188
189 return output.String()
190}