1package tools
2
3import (
4 "cmp"
5 "context"
6 _ "embed"
7 "errors"
8 "fmt"
9 "log/slog"
10 "maps"
11 "path/filepath"
12 "regexp"
13 "slices"
14 "sort"
15 "strings"
16
17 "charm.land/fantasy"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/lsp"
20 "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
21)
22
// ReferencesParams holds the arguments accepted by the lsp_references tool.
// The json tags name the arguments and the description tags document them
// (presumably consumed when the tool's parameter schema is built — confirm).
type ReferencesParams struct {
	// Symbol is the (possibly qualified) name to look up. It must be non-empty.
	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
	// Path optionally narrows the search to a directory or file; when empty
	// the tool searches ".".
	Path   string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
}
27
// ReferencesToolName is the name under which the references tool is
// registered with the agent.
const ReferencesToolName = "lsp_references"
29
30//go:embed references.md
31var referencesDescription []byte
32
33func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
34 return fantasy.NewAgentTool(
35 ReferencesToolName,
36 string(referencesDescription),
37 func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
38 if params.Symbol == "" {
39 return fantasy.NewTextErrorResponse("symbol is required"), nil
40 }
41
42 if lspClients.Len() == 0 {
43 return fantasy.NewTextErrorResponse("no LSP clients available"), nil
44 }
45
46 workingDir := cmp.Or(params.Path, ".")
47
48 matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
49 if err != nil {
50 return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
51 }
52
53 if len(matches) == 0 {
54 return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
55 }
56
57 var allLocations []protocol.Location
58 var allErrs error
59 for _, match := range matches {
60 locations, err := find(ctx, lspClients, params.Symbol, match)
61 if err != nil {
62 if strings.Contains(err.Error(), "no identifier found") {
63 // grep probably matched a comment, string value, or something else that's irrelevant
64 continue
65 }
66 slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
67 allErrs = errors.Join(allErrs, err)
68 continue
69 }
70 allLocations = append(allLocations, locations...)
71 // XXX: should we break here or look for all results?
72 }
73
74 if len(allLocations) > 0 {
75 output := formatReferences(cleanupLocations(allLocations))
76 return fantasy.NewTextResponse(output), nil
77 }
78
79 if allErrs != nil {
80 return fantasy.NewTextErrorResponse(allErrs.Error()), nil
81 }
82 return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
83 })
84}
85
86func find(ctx context.Context, lspClients *csync.Map[string, *lsp.Client], symbol string, match grepMatch) ([]protocol.Location, error) {
87 absPath, err := filepath.Abs(match.path)
88 if err != nil {
89 return nil, fmt.Errorf("failed to get absolute path: %s", err)
90 }
91
92 var client *lsp.Client
93 for c := range lspClients.Seq() {
94 if c.HandlesFile(absPath) {
95 client = c
96 break
97 }
98 }
99
100 if client == nil {
101 slog.Warn("No LSP clients to handle", "path", match.path)
102 return nil, nil
103 }
104
105 return client.FindReferences(
106 ctx,
107 absPath,
108 match.lineNum,
109 match.charNum+getSymbolOffset(symbol),
110 true,
111 )
112}
113
// getSymbolOffset returns the byte offset of the last name component in a
// qualified symbol, e.g. "Bar" in "foo.Bar", "method" in "Class::method",
// "User" in `App\Models\User`, or "baz" in "Foo::Bar.baz".
//
// Recognised separators are "::" (Rust, C++, Ruby modules/classes, PHP
// static), "." (Go, Python, JavaScript, Java, C#, Ruby methods) and "\"
// (PHP namespaces). The rightmost separator of any kind wins, so symbols
// that mix separators still resolve to their final component. An
// unqualified symbol yields 0.
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range []string{"::", ".", "\\"} {
		if idx := strings.LastIndex(symbol, sep); idx != -1 {
			offset = max(offset, idx+len(sep))
		}
	}
	return offset
}
131
132func cleanupLocations(locations []protocol.Location) []protocol.Location {
133 slices.SortFunc(locations, func(a, b protocol.Location) int {
134 if a.URI != b.URI {
135 return strings.Compare(string(a.URI), string(b.URI))
136 }
137 if a.Range.Start.Line != b.Range.Start.Line {
138 return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
139 }
140 return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
141 })
142 return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
143 return a.URI == b.URI &&
144 a.Range.Start.Line == b.Range.Start.Line &&
145 a.Range.Start.Character == b.Range.Start.Character
146 })
147}
148
149func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
150 files := make(map[string][]protocol.Location)
151 for _, loc := range locations {
152 path, err := loc.URI.Path()
153 if err != nil {
154 slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
155 continue
156 }
157 files[path] = append(files[path], loc)
158 }
159 return files
160}
161
162func formatReferences(locations []protocol.Location) string {
163 fileRefs := groupByFilename(locations)
164 files := slices.Collect(maps.Keys(fileRefs))
165 sort.Strings(files)
166
167 var output strings.Builder
168 output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
169
170 for _, file := range files {
171 refs := fileRefs[file]
172 output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
173 for _, ref := range refs {
174 line := ref.Range.Start.Line + 1
175 char := ref.Range.Start.Character + 1
176 output.WriteString(fmt.Sprintf(" Line %d, Column %d\n", line, char))
177 }
178 output.WriteString("\n")
179 }
180
181 return output.String()
182}