1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	_ "embed"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"path/filepath"
 12	"regexp"
 13	"slices"
 14	"sort"
 15	"strings"
 16
 17	"charm.land/fantasy"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 21)
 22
// ReferencesParams is the input of the lsp_references tool. The json and
// description tags name and describe each argument; they are presumably turned
// into the tool's input schema by fantasy.NewAgentTool — confirm.
type ReferencesParams struct {
	// Symbol is the (possibly qualified) name to look up. Required: an empty
	// value makes the tool return an error response.
	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
	// Path optionally restricts the search to a directory or file. When empty,
	// the tool searches ".".
	Path string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
}
 27
// referencesTool holds the LSP clients keyed by name for the references tool.
// NOTE(review): NewReferencesTool in this file does not construct this type; it
// appears to exist only to carry the Name method — confirm it is still needed.
type referencesTool struct {
	lspClients *csync.Map[string, *lsp.Client]
}
 31
// ReferencesToolName is the name under which the references tool is registered.
const ReferencesToolName = "lsp_references"
 33
// referencesDescription is the tool description handed to fantasy.NewAgentTool,
// embedded at build time from references.md next to this file.
//
//go:embed references.md
var referencesDescription []byte
 36
 37func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
 38	return fantasy.NewAgentTool(
 39		ReferencesToolName,
 40		string(referencesDescription),
 41		func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 42			if params.Symbol == "" {
 43				return fantasy.NewTextErrorResponse("symbol is required"), nil
 44			}
 45
 46			if lspClients.Len() == 0 {
 47				return fantasy.NewTextErrorResponse("no LSP clients available"), nil
 48			}
 49
 50			workingDir := cmp.Or(params.Path, ".")
 51
 52			matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
 53			if err != nil {
 54				return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
 55			}
 56
 57			if len(matches) == 0 {
 58				return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
 59			}
 60
 61			var allLocations []protocol.Location
 62			var allErrs error
 63			for _, match := range matches {
 64				locations, err := find(ctx, lspClients, params.Symbol, match)
 65				if err != nil {
 66					if strings.Contains(err.Error(), "no identifier found") {
 67						// grep probably matched a comment, string value, or something else that's irrelevant
 68						continue
 69					}
 70					slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
 71					allErrs = errors.Join(allErrs, err)
 72					continue
 73				}
 74				allLocations = append(allLocations, locations...)
 75				// XXX: should we break here or look for all results?
 76			}
 77
 78			if len(allLocations) > 0 {
 79				output := formatReferences(cleanupLocations(allLocations))
 80				return fantasy.NewTextResponse(output), nil
 81			}
 82
 83			if allErrs != nil {
 84				return fantasy.NewTextErrorResponse(allErrs.Error()), nil
 85			}
 86			return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
 87		})
 88}
 89
 90func (r *referencesTool) Name() string {
 91	return ReferencesToolName
 92}
 93
 94func find(ctx context.Context, lspClients *csync.Map[string, *lsp.Client], symbol string, match grepMatch) ([]protocol.Location, error) {
 95	absPath, err := filepath.Abs(match.path)
 96	if err != nil {
 97		return nil, fmt.Errorf("failed to get absolute path: %s", err)
 98	}
 99
100	var client *lsp.Client
101	for c := range lspClients.Seq() {
102		if c.HandlesFile(absPath) {
103			client = c
104			break
105		}
106	}
107
108	if client == nil {
109		slog.Warn("No LSP clients to handle", "path", match.path)
110		return nil, nil
111	}
112
113	return client.FindReferences(
114		ctx,
115		absPath,
116		match.lineNum,
117		match.charNum+getSymbolOffset(symbol),
118		true,
119	)
120}
121
// getSymbolOffset returns the offset at which the final name segment of a
// qualified symbol begins. For example, it returns 4 for "Bar" in "foo.Bar"
// and 7 for "method" in "Class::method".
//
// Recognized separators are:
//   - "::" (Rust, C++, Ruby modules/classes, PHP static access)
//   - "." (Go, Python, JavaScript, Java, C#, Ruby methods)
//   - "\" (PHP namespaces)
//
// When several separators occur, the rightmost one wins, so "Foo::Bar.baz"
// yields the offset of "baz" rather than "Bar". An unqualified symbol yields 0.
//
// The offset is measured in bytes. NOTE(review): it equals a character offset
// only for ASCII symbols — confirm this matches the units of the grep column it
// is added to.
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range []string{"::", ".", `\`} {
		if idx := strings.LastIndex(symbol, sep); idx != -1 {
			offset = max(offset, idx+len(sep))
		}
	}
	return offset
}
139
140func cleanupLocations(locations []protocol.Location) []protocol.Location {
141	slices.SortFunc(locations, func(a, b protocol.Location) int {
142		if a.URI != b.URI {
143			return strings.Compare(string(a.URI), string(b.URI))
144		}
145		if a.Range.Start.Line != b.Range.Start.Line {
146			return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
147		}
148		return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
149	})
150	return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
151		return a.URI == b.URI &&
152			a.Range.Start.Line == b.Range.Start.Line &&
153			a.Range.Start.Character == b.Range.Start.Character
154	})
155}
156
157func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
158	files := make(map[string][]protocol.Location)
159	for _, loc := range locations {
160		path, err := loc.URI.Path()
161		if err != nil {
162			slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
163			continue
164		}
165		files[path] = append(files[path], loc)
166	}
167	return files
168}
169
170func formatReferences(locations []protocol.Location) string {
171	fileRefs := groupByFilename(locations)
172	files := slices.Collect(maps.Keys(fileRefs))
173	sort.Strings(files)
174
175	var output strings.Builder
176	output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
177
178	for _, file := range files {
179		refs := fileRefs[file]
180		output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
181		for _, ref := range refs {
182			line := ref.Range.Start.Line + 1
183			char := ref.Range.Start.Character + 1
184			output.WriteString(fmt.Sprintf("  Line %d, Column %d\n", line, char))
185		}
186		output.WriteString("\n")
187	}
188
189	return output.String()
190}