references.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	_ "embed"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"path/filepath"
 12	"regexp"
 13	"slices"
 14	"sort"
 15	"strings"
 16
 17	"charm.land/fantasy"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 20)
 21
 22type ReferencesParams struct {
 23	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
 24	Path   string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
 25}
 26
 27type referencesTool struct {
 28	lspManager *lsp.Manager
 29}
 30
 31const ReferencesToolName = "lsp_references"
 32
 33//go:embed references.md
 34var referencesDescription string
 35
 36func NewReferencesTool(lspManager *lsp.Manager) fantasy.AgentTool {
 37	return fantasy.NewAgentTool(
 38		ReferencesToolName,
 39		referencesDescription,
 40		func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 41			if params.Symbol == "" {
 42				return fantasy.NewTextErrorResponse("symbol is required"), nil
 43			}
 44
 45			if lspManager.Clients().Len() == 0 {
 46				return fantasy.NewTextErrorResponse("no LSP clients available"), nil
 47			}
 48
 49			workingDir := cmp.Or(params.Path, ".")
 50
 51			matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
 52			if err != nil {
 53				return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
 54			}
 55
 56			if len(matches) == 0 {
 57				return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
 58			}
 59
 60			var allLocations []protocol.Location
 61			var allErrs error
 62			for _, match := range matches {
 63				locations, err := find(ctx, lspManager, params.Symbol, match)
 64				if err != nil {
 65					if strings.Contains(err.Error(), "no identifier found") {
 66						// grep probably matched a comment, string value, or something else that's irrelevant
 67						continue
 68					}
 69					slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
 70					allErrs = errors.Join(allErrs, err)
 71					continue
 72				}
 73				allLocations = append(allLocations, locations...)
 74				// Once we have results, we're done - LSP returns all references
 75				// for the symbol, not just from this file.
 76				if len(locations) > 0 {
 77					break
 78				}
 79			}
 80
 81			if len(allLocations) > 0 {
 82				output := formatReferences(cleanupLocations(allLocations))
 83				return fantasy.NewTextResponse(output), nil
 84			}
 85
 86			if allErrs != nil {
 87				return fantasy.NewTextErrorResponse(allErrs.Error()), nil
 88			}
 89			return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
 90		},
 91	)
 92}
 93
 94func (r *referencesTool) Name() string {
 95	return ReferencesToolName
 96}
 97
 98func find(ctx context.Context, lspManager *lsp.Manager, symbol string, match grepMatch) ([]protocol.Location, error) {
 99	absPath, err := filepath.Abs(match.path)
100	if err != nil {
101		return nil, fmt.Errorf("failed to get absolute path: %s", err)
102	}
103
104	var client *lsp.Client
105	for c := range lspManager.Clients().Seq() {
106		if c.HandlesFile(absPath) {
107			client = c
108			break
109		}
110	}
111
112	if client == nil {
113		slog.Warn("No LSP clients to handle", "path", match.path)
114		return nil, nil
115	}
116
117	return client.FindReferences(
118		ctx,
119		absPath,
120		match.lineNum,
121		match.charNum+getSymbolOffset(symbol),
122		true,
123	)
124}
125
// getSymbolOffset returns the byte offset of the actual symbol name within a
// qualified symbol (e.g., "Bar" in "foo.Bar" or "method" in "Class::method").
//
// The recognised qualifier separators are:
//   - "::" (Rust, C++, Ruby modules/classes, PHP static)
//   - "."  (Go, Python, JavaScript, Java, C#, Ruby methods)
//   - "\"  (PHP namespaces)
//
// The offset just past the right-most separator of any kind is returned, so a
// symbol mixing separators such as "Foo::Bar.baz" resolves to "baz". It
// returns 0 when the symbol contains no separator.
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range []string{"::", ".", "\\"} {
		idx := strings.LastIndex(symbol, sep)
		if idx == -1 {
			continue
		}
		// Keep whichever separator ends furthest to the right.
		if end := idx + len(sep); end > offset {
			offset = end
		}
	}
	return offset
}
143
144func cleanupLocations(locations []protocol.Location) []protocol.Location {
145	slices.SortFunc(locations, func(a, b protocol.Location) int {
146		if a.URI != b.URI {
147			return strings.Compare(string(a.URI), string(b.URI))
148		}
149		if a.Range.Start.Line != b.Range.Start.Line {
150			return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
151		}
152		return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
153	})
154	return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
155		return a.URI == b.URI &&
156			a.Range.Start.Line == b.Range.Start.Line &&
157			a.Range.Start.Character == b.Range.Start.Character
158	})
159}
160
161func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
162	files := make(map[string][]protocol.Location)
163	for _, loc := range locations {
164		path, err := loc.URI.Path()
165		if err != nil {
166			slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
167			continue
168		}
169		files[path] = append(files[path], loc)
170	}
171	return files
172}
173
174func formatReferences(locations []protocol.Location) string {
175	fileRefs := groupByFilename(locations)
176	files := slices.Collect(maps.Keys(fileRefs))
177	sort.Strings(files)
178
179	var output strings.Builder
180	fmt.Fprintf(&output, "Found %d reference(s) in %d file(s):\n\n", len(locations), len(files))
181
182	for _, file := range files {
183		refs := fileRefs[file]
184		fmt.Fprintf(&output, "%s (%d reference(s)):\n", file, len(refs))
185		for _, ref := range refs {
186			line := ref.Range.Start.Line + 1
187			char := ref.Range.Start.Character + 1
188			fmt.Fprintf(&output, "  Line %d, Column %d\n", line, char)
189		}
190		output.WriteString("\n")
191	}
192
193	return output.String()
194}