references.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	_ "embed"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"path/filepath"
 12	"regexp"
 13	"slices"
 14	"sort"
 15	"strings"
 16
 17	"charm.land/fantasy"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 21)
 22
 23type ReferencesParams struct {
 24	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
 25	Path   string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
 26}
 27
 28const ReferencesToolName = "lsp_references"
 29
 30//go:embed references.md
 31var referencesDescription []byte
 32
 33func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
 34	return fantasy.NewAgentTool(
 35		ReferencesToolName,
 36		string(referencesDescription),
 37		func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 38			if params.Symbol == "" {
 39				return fantasy.NewTextErrorResponse("symbol is required"), nil
 40			}
 41
 42			if lspClients.Len() == 0 {
 43				return fantasy.NewTextErrorResponse("no LSP clients available"), nil
 44			}
 45
 46			workingDir := cmp.Or(params.Path, ".")
 47
 48			matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
 49			if err != nil {
 50				return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
 51			}
 52
 53			if len(matches) == 0 {
 54				return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
 55			}
 56
 57			var allLocations []protocol.Location
 58			var allErrs error
 59			for _, match := range matches {
 60				locations, err := find(ctx, lspClients, params.Symbol, match)
 61				if err != nil {
 62					if strings.Contains(err.Error(), "no identifier found") {
 63						// grep probably matched a comment, string value, or something else that's irrelevant
 64						continue
 65					}
 66					slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
 67					allErrs = errors.Join(allErrs, err)
 68					continue
 69				}
 70				allLocations = append(allLocations, locations...)
 71				// XXX: should we break here or look for all results?
 72			}
 73
 74			if len(allLocations) > 0 {
 75				output := formatReferences(cleanupLocations(allLocations))
 76				return fantasy.NewTextResponse(output), nil
 77			}
 78
 79			if allErrs != nil {
 80				return fantasy.NewTextErrorResponse(allErrs.Error()), nil
 81			}
 82			return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
 83		})
 84}
 85
 86func find(ctx context.Context, lspClients *csync.Map[string, *lsp.Client], symbol string, match grepMatch) ([]protocol.Location, error) {
 87	absPath, err := filepath.Abs(match.path)
 88	if err != nil {
 89		return nil, fmt.Errorf("failed to get absolute path: %s", err)
 90	}
 91
 92	var client *lsp.Client
 93	for c := range lspClients.Seq() {
 94		if c.HandlesFile(absPath) {
 95			client = c
 96			break
 97		}
 98	}
 99
100	if client == nil {
101		slog.Warn("No LSP clients to handle", "path", match.path)
102		return nil, nil
103	}
104
105	return client.FindReferences(
106		ctx,
107		absPath,
108		match.lineNum,
109		match.charNum+getSymbolOffset(symbol),
110		true,
111	)
112}
113
// getSymbolOffset returns the byte offset at which the final name component of
// a qualified symbol starts, e.g. 4 for "foo.Bar" (pointing at "Bar") or 7 for
// "Class::method". Recognized separators are:
//
//   - "::" (Rust, C++, Ruby modules/classes, PHP static),
//   - "->" (PHP/C++ member access),
//   - "."  (Go, Python, JavaScript, Java, C#, Ruby methods),
//   - "\"  (PHP namespaces).
//
// When separators are mixed, the one occurring last in the string wins, so
// "Foo::Bar.baz" yields the offset of "baz". Unqualified symbols yield 0.
// NOTE(review): the result is a byte offset; for non-ASCII qualifiers it may
// differ from the UTF-16 character units used by LSP positions.
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range []string{"::", "->", ".", `\`} {
		if idx := strings.LastIndex(symbol, sep); idx != -1 {
			offset = max(offset, idx+len(sep))
		}
	}
	return offset
}
131
132func cleanupLocations(locations []protocol.Location) []protocol.Location {
133	slices.SortFunc(locations, func(a, b protocol.Location) int {
134		if a.URI != b.URI {
135			return strings.Compare(string(a.URI), string(b.URI))
136		}
137		if a.Range.Start.Line != b.Range.Start.Line {
138			return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
139		}
140		return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
141	})
142	return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
143		return a.URI == b.URI &&
144			a.Range.Start.Line == b.Range.Start.Line &&
145			a.Range.Start.Character == b.Range.Start.Character
146	})
147}
148
149func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
150	files := make(map[string][]protocol.Location)
151	for _, loc := range locations {
152		path, err := loc.URI.Path()
153		if err != nil {
154			slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
155			continue
156		}
157		files[path] = append(files[path], loc)
158	}
159	return files
160}
161
162func formatReferences(locations []protocol.Location) string {
163	fileRefs := groupByFilename(locations)
164	files := slices.Collect(maps.Keys(fileRefs))
165	sort.Strings(files)
166
167	var output strings.Builder
168	output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
169
170	for _, file := range files {
171		refs := fileRefs[file]
172		output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
173		for _, ref := range refs {
174			line := ref.Range.Start.Line + 1
175			char := ref.Range.Start.Character + 1
176			output.WriteString(fmt.Sprintf("  Line %d, Column %d\n", line, char))
177		}
178		output.WriteString("\n")
179	}
180
181	return output.String()
182}