1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"errors"
  9	"fmt"
 10	"log/slog"
 11	"maps"
 12	"path/filepath"
 13	"regexp"
 14	"slices"
 15	"sort"
 16	"strings"
 17
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
 21)
 22
 23type ReferencesParams struct {
 24	Symbol string `json:"symbol"`
 25	Path   string `json:"path"`
 26}
 27
 28type referencesTool struct {
 29	lspClients *csync.Map[string, *lsp.Client]
 30}
 31
// ReferencesToolName is the name the references tool reports from Name and Info.
const ReferencesToolName = "lsp_references"

// referencesDescription holds the contents of references.md, embedded at build
// time. Info uses it as the tool's Description.
//
//go:embed references.md
var referencesDescription []byte
 36
 37func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool {
 38	return &referencesTool{
 39		lspClients,
 40	}
 41}
 42
 43func (r *referencesTool) Name() string {
 44	return ReferencesToolName
 45}
 46
 47func (r *referencesTool) Info() ToolInfo {
 48	return ToolInfo{
 49		Name:        ReferencesToolName,
 50		Description: string(referencesDescription),
 51		Parameters: map[string]any{
 52			"symbol": map[string]any{
 53				"type":        "string",
 54				"description": "The symbol name to search for (e.g., function name, variable name, type name).",
 55			},
 56			"path": map[string]any{
 57				"type":        "string",
 58				"description": "The directory to search in. Should be the entire project most of the time. Defaults to the current working directory.",
 59			},
 60		},
 61		Required: []string{"symbol"},
 62	}
 63}
 64
 65func (r *referencesTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 66	var params ReferencesParams
 67	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 68		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 69	}
 70
 71	if params.Symbol == "" {
 72		return NewTextErrorResponse("symbol is required"), nil
 73	}
 74
 75	if r.lspClients.Len() == 0 {
 76		return NewTextErrorResponse("no LSP clients available"), nil
 77	}
 78
 79	workingDir := cmp.Or(params.Path, ".")
 80
 81	matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
 82	if err != nil {
 83		return NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
 84	}
 85
 86	if len(matches) == 0 {
 87		return NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
 88	}
 89
 90	var allLocations []protocol.Location
 91	var allErrs error
 92	for _, match := range matches {
 93		locations, err := r.find(ctx, params.Symbol, match)
 94		if err != nil {
 95			if strings.Contains(err.Error(), "no identifier found") {
 96				// grep probably matched a comment, string value, or something else that's irrelevant
 97				continue
 98			}
 99			slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
100			allErrs = errors.Join(allErrs, err)
101			continue
102		}
103		allLocations = append(allLocations, locations...)
104		// XXX: should we break here or look for all results?
105	}
106
107	if len(allLocations) > 0 {
108		output := formatReferences(cleanupLocations(allLocations))
109		return NewTextResponse(output), nil
110	}
111
112	if allErrs != nil {
113		return NewTextErrorResponse(allErrs.Error()), nil
114	}
115	return NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
116}
117
118func (r *referencesTool) find(ctx context.Context, symbol string, match grepMatch) ([]protocol.Location, error) {
119	absPath, err := filepath.Abs(match.path)
120	if err != nil {
121		return nil, fmt.Errorf("failed to get absolute path: %s", err)
122	}
123
124	var client *lsp.Client
125	for c := range r.lspClients.Seq() {
126		if c.HandlesFile(absPath) {
127			client = c
128			break
129		}
130	}
131
132	if client == nil {
133		slog.Warn("No LSP clients to handle", "path", match.path)
134		return nil, nil
135	}
136
137	return client.FindReferences(
138		ctx,
139		absPath,
140		match.lineNum,
141		match.charNum+getSymbolOffset(symbol),
142		true,
143	)
144}
145
// getSymbolOffset returns the byte offset of the final name component in a
// possibly qualified symbol, or 0 when the symbol is unqualified.
//
// Recognized separators are "::" (Rust, C++, Ruby modules/classes, PHP
// static), "." (Go, Python, JavaScript, Java, C#, Ruby methods) and "\" (PHP
// namespaces). The rightmost separator of any kind wins. For example,
// "foo.Bar" -> 4 ("Bar"), "Class::method" -> 7 ("method"), and the mixed form
// "ns::obj.method" -> 8 ("method", not "obj").
func getSymbolOffset(symbol string) int {
	offset := 0
	for _, sep := range []string{"::", ".", "\\"} {
		if idx := strings.LastIndex(symbol, sep); idx != -1 {
			// Compare across all separators so a later "." is not hidden by an
			// earlier "::" (and vice versa).
			offset = max(offset, idx+len(sep))
		}
	}
	return offset
}
163
164func cleanupLocations(locations []protocol.Location) []protocol.Location {
165	slices.SortFunc(locations, func(a, b protocol.Location) int {
166		if a.URI != b.URI {
167			return strings.Compare(string(a.URI), string(b.URI))
168		}
169		if a.Range.Start.Line != b.Range.Start.Line {
170			return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
171		}
172		return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
173	})
174	return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
175		return a.URI == b.URI &&
176			a.Range.Start.Line == b.Range.Start.Line &&
177			a.Range.Start.Character == b.Range.Start.Character
178	})
179}
180
181func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
182	files := make(map[string][]protocol.Location)
183	for _, loc := range locations {
184		path, err := loc.URI.Path()
185		if err != nil {
186			slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
187			continue
188		}
189		files[path] = append(files[path], loc)
190	}
191	return files
192}
193
194func formatReferences(locations []protocol.Location) string {
195	fileRefs := groupByFilename(locations)
196	files := slices.Collect(maps.Keys(fileRefs))
197	sort.Strings(files)
198
199	var output strings.Builder
200	output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
201
202	for _, file := range files {
203		refs := fileRefs[file]
204		output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
205		for _, ref := range refs {
206			line := ref.Range.Start.Line + 1
207			char := ref.Range.Start.Character + 1
208			output.WriteString(fmt.Sprintf("  Line %d, Column %d\n", line, char))
209		}
210		output.WriteString("\n")
211	}
212
213	return output.String()
214}