From a64a4def3ea855ac2c84cd0c12d165fe5098b1a5 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Fri, 17 Oct 2025 09:56:58 -0300 Subject: [PATCH] feat(lsp): find references tool (#1233) Signed-off-by: Carlos Alexandro Becker --- go.mod | 2 +- go.sum | 4 +- internal/llm/agent/agent.go | 2 +- internal/llm/tools/diagnostics.go | 2 +- internal/llm/tools/grep.go | 108 +++++++------- internal/llm/tools/grep_test.go | 29 ++++ internal/llm/tools/references.go | 214 +++++++++++++++++++++++++++ internal/llm/tools/references.md | 36 +++++ internal/llm/tools/rg.go | 2 +- internal/llm/tools/testdata/grep.txt | 3 + internal/lsp/client.go | 10 ++ 11 files changed, 348 insertions(+), 64 deletions(-) create mode 100644 internal/llm/tools/references.go create mode 100644 internal/llm/tools/references.md create mode 100644 internal/llm/tools/testdata/grep.txt diff --git a/go.mod b/go.mod index e0b92a9380af54233306de80c826a5191878298a..c0bc32fe29ac100f98c589edf7697f104aa854a5 100644 --- a/go.mod +++ b/go.mod @@ -77,7 +77,7 @@ require ( github.com/charmbracelet/ultraviolet v0.0.0-20250915111650-81d4262876ef github.com/charmbracelet/x/cellbuf v0.0.14-0.20250811133356-e0c5dbe5ea4a // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250829135019-44e44e21330d - github.com/charmbracelet/x/powernap v0.0.0-20250919153222-1038f7e6fef4 + github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 github.com/charmbracelet/x/term v0.2.1 github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect diff --git a/go.sum b/go.sum index d1c0349a66d5e8c0e9bf6968d849cb0cbf6d26c5..0fa4e9f695cf5d60a60be753aaee9a0b2e14c192 100644 --- a/go.sum +++ b/go.sum @@ -106,8 +106,8 @@ github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sB github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8= github.com/charmbracelet/x/exp/slice v0.0.0-20250829135019-44e44e21330d 
h1:H2oh4WlSsXy8qwLd7I3eAvPd/X3S40aM9l+h47WF1eA= github.com/charmbracelet/x/exp/slice v0.0.0-20250829135019-44e44e21330d/go.mod h1:vI5nDVMWi6veaYH+0Fmvpbe/+cv/iJfMntdh+N0+Tms= -github.com/charmbracelet/x/powernap v0.0.0-20250919153222-1038f7e6fef4 h1:ZhDGU688EHQXslD9KphRpXwK0pKP03egUoZAATUDlV0= -github.com/charmbracelet/x/powernap v0.0.0-20250919153222-1038f7e6fef4/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4= +github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4 h1:i/XilBPYK4L1Yo/mc9FPx0SyJzIsN0y4sj1MWq9Sscc= +github.com/charmbracelet/x/powernap v0.0.0-20251015113943-25f979b54ad4/go.mod h1:cmdl5zlP5mR8TF2Y68UKc7hdGUDiSJ2+4hk0h04Hsx4= github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index e338eef782912bdfea48ca72ebfd33c4cd981f62..b2b222db1a481b1eb4c7e945467bd5c74506d5ab 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -525,7 +525,7 @@ func (a *agent) getAllTools() ([]tools.BaseTool, error) { if a.agentCfg.ID == "coder" { allTools = slices.AppendSeq(allTools, a.mcpTools.Seq()) if a.lspClients.Len() > 0 { - allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients)) + allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients), tools.NewReferencesTool(a.lspClients)) } } if a.agentToolFn != nil { diff --git a/internal/llm/tools/diagnostics.go b/internal/llm/tools/diagnostics.go index 8e0c332cef76e40d5e24e74ed3260b95aab8b04b..c2625e9495963f1de467656b2d74e71e0b3c78fa 100644 --- a/internal/llm/tools/diagnostics.go +++ b/internal/llm/tools/diagnostics.go @@ -23,7 +23,7 @@ type diagnosticsTool struct { lspClients *csync.Map[string, *lsp.Client] } -const DiagnosticsToolName = "diagnostics" +const DiagnosticsToolName = 
"lsp_diagnostics" //go:embed diagnostics.md var diagnosticsDescription []byte diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 237d4e18dab0bc518b9d4b6e2c73ef5035d2b348..ed844b6c10081deab6a314f380da72e0893102ca 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -2,6 +2,7 @@ package tools import ( "bufio" + "bytes" "context" _ "embed" "encoding/json" @@ -13,7 +14,6 @@ import ( "path/filepath" "regexp" "sort" - "strconv" "strings" "sync" "time" @@ -82,6 +82,7 @@ type grepMatch struct { path string modTime time.Time lineNum int + charNum int lineText string } @@ -189,7 +190,11 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) fmt.Fprintf(&output, "%s:\n", match.path) } if match.lineNum > 0 { - fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText) + if match.charNum > 0 { + fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, match.lineText) + } else { + fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText) + } } else { fmt.Fprintf(&output, " %s\n", match.path) } @@ -252,66 +257,51 @@ func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]gr return nil, err } - lines := strings.Split(strings.TrimSpace(string(output)), "\n") - matches := make([]grepMatch, 0, len(lines)) - - for _, line := range lines { - if line == "" { + var matches []grepMatch + for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) { + if len(line) == 0 { continue } - - // Parse ripgrep output using null separation - filePath, lineNumStr, lineText, ok := parseRipgrepLine(line) - if !ok { + var match ripgrepMatch + if err := json.Unmarshal(line, &match); err != nil { continue } - - lineNum, err := strconv.Atoi(lineNumStr) - if err != nil { + if match.Type != "match" { continue } - - fileInfo, err := os.Stat(filePath) - if err != nil { - continue // Skip files we can't access + for _, m := range match.Data.Submatches { + fi, 
err := os.Stat(match.Data.Path.Text) + if err != nil { + continue // Skip files we can't access + } + matches = append(matches, grepMatch{ + path: match.Data.Path.Text, + modTime: fi.ModTime(), + lineNum: match.Data.LineNumber, + charNum: m.Start + 1, // ensure 1-based + lineText: strings.TrimSpace(match.Data.Lines.Text), + }) + // only get the first match of each line + break } - - matches = append(matches, grepMatch{ - path: filePath, - modTime: fileInfo.ModTime(), - lineNum: lineNum, - lineText: lineText, - }) } - return matches, nil } -// parseRipgrepLine parses ripgrep output with null separation to handle Windows paths -func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) { - // Split on null byte first to separate filename from rest - parts := strings.SplitN(line, "\x00", 2) - if len(parts) != 2 { - return "", "", "", false - } - - filePath = parts[0] - remainder := parts[1] - - // Now split the remainder on first colon: "linenum:content" - colonIndex := strings.Index(remainder, ":") - if colonIndex == -1 { - return "", "", "", false - } - - lineNumStr := remainder[:colonIndex] - lineText = remainder[colonIndex+1:] - - if _, err := strconv.Atoi(lineNumStr); err != nil { - return "", "", "", false - } - - return filePath, lineNumStr, lineText, true +type ripgrepMatch struct { + Type string `json:"type"` + Data struct { + Path struct { + Text string `json:"text"` + } `json:"path"` + Lines struct { + Text string `json:"text"` + } `json:"lines"` + LineNumber int `json:"line_number"` + Submatches []struct { + Start int `json:"start"` + } `json:"submatches"` + } `json:"data"` } func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) { @@ -363,7 +353,7 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error return nil } - match, lineNum, lineText, err := fileContainsPattern(path, regex) + match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex) if err != nil { return 
nil // Skip files we can't read } @@ -373,6 +363,7 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error path: path, modTime: info.ModTime(), lineNum: lineNum, + charNum: charNum, lineText: lineText, }) @@ -390,15 +381,15 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error return matches, nil } -func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) { +func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) { // Only search text files. if !isTextFile(filePath) { - return false, 0, "", nil + return false, 0, 0, "", nil } file, err := os.Open(filePath) if err != nil { - return false, 0, "", err + return false, 0, 0, "", err } defer file.Close() @@ -407,12 +398,13 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, st for scanner.Scan() { lineNum++ line := scanner.Text() - if pattern.MatchString(line) { - return true, lineNum, line, nil + if loc := pattern.FindStringIndex(line); loc != nil { + charNum := loc[0] + 1 + return true, lineNum, charNum, line, nil } } - return false, 0, "", scanner.Err() + return false, 0, 0, "", scanner.Err() } // isTextFile checks if a file is a text file by examining its MIME type. 
diff --git a/internal/llm/tools/grep_test.go b/internal/llm/tools/grep_test.go index 435b3045b93a8e1297ff2aaeff9ee8977b974b56..753ee05942b78578fd2e9170384cac3fd5d9496e 100644 --- a/internal/llm/tools/grep_test.go +++ b/internal/llm/tools/grep_test.go @@ -390,3 +390,32 @@ func TestIsTextFile(t *testing.T) { }) } } + +func TestColumnMatch(t *testing.T) { + t.Parallel() + + // Test both implementations + for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){ + "regex": searchFilesWithRegex, + "rg": func(pattern, path, include string) ([]grepMatch, error) { + return searchWithRipgrep(t.Context(), pattern, path, include) + }, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + + if name == "rg" && getRg() == "" { + t.Skip("rg is not in $PATH") + } + + matches, err := fn("THIS", "./testdata/", "") + require.NoError(t, err) + require.Len(t, matches, 1) + match := matches[0] + require.Equal(t, 2, match.lineNum) + require.Equal(t, 14, match.charNum) + require.Equal(t, "I wanna grep THIS particular word", match.lineText) + require.Equal(t, "testdata/grep.txt", filepath.ToSlash(filepath.Clean(match.path))) + }) + } +} diff --git a/internal/llm/tools/references.go b/internal/llm/tools/references.go new file mode 100644 index 0000000000000000000000000000000000000000..a1bc393cd5d28755f5f0b694c1b2df40bee1a39e --- /dev/null +++ b/internal/llm/tools/references.go @@ -0,0 +1,214 @@ +package tools + +import ( + "cmp" + "context" + _ "embed" + "encoding/json" + "errors" + "fmt" + "log/slog" + "maps" + "path/filepath" + "regexp" + "slices" + "sort" + "strings" + + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/x/powernap/pkg/lsp/protocol" +) + +type ReferencesParams struct { + Symbol string `json:"symbol"` + Path string `json:"path"` +} + +type referencesTool struct { + lspClients *csync.Map[string, *lsp.Client] +} + +const ReferencesToolName = "lsp_references" + 
+//go:embed references.md +var referencesDescription []byte + +func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) BaseTool { + return &referencesTool{ + lspClients, + } +} + +func (r *referencesTool) Name() string { + return ReferencesToolName +} + +func (r *referencesTool) Info() ToolInfo { + return ToolInfo{ + Name: ReferencesToolName, + Description: string(referencesDescription), + Parameters: map[string]any{ + "symbol": map[string]any{ + "type": "string", + "description": "The symbol name to search for (e.g., function name, variable name, type name).", + }, + "path": map[string]any{ + "type": "string", + "description": "The directory to search in. Should be the entire project most of the time. Defaults to the current working directory.", + }, + }, + Required: []string{"symbol"}, + } +} + +func (r *referencesTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params ReferencesParams + if err := json.Unmarshal([]byte(call.Input), &params); err != nil { + return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + } + + if params.Symbol == "" { + return NewTextErrorResponse("symbol is required"), nil + } + + if r.lspClients.Len() == 0 { + return NewTextErrorResponse("no LSP clients available"), nil + } + + workingDir := cmp.Or(params.Path, ".") + + matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil + } + + if len(matches) == 0 { + return NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil + } + + var allLocations []protocol.Location + var allErrs error + for _, match := range matches { + locations, err := r.find(ctx, params.Symbol, match) + if err != nil { + if strings.Contains(err.Error(), "no identifier found") { + // grep probably matched a comment, string value, or something else that's irrelevant + continue + } + slog.Error("Failed to 
find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum) + allErrs = errors.Join(allErrs, err) + continue + } + allLocations = append(allLocations, locations...) + // XXX: should we break here or look for all results? + } + + if len(allLocations) > 0 { + output := formatReferences(cleanupLocations(allLocations)) + return NewTextResponse(output), nil + } + + if allErrs != nil { + return NewTextErrorResponse(allErrs.Error()), nil + } + return NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil +} + +func (r *referencesTool) find(ctx context.Context, symbol string, match grepMatch) ([]protocol.Location, error) { + absPath, err := filepath.Abs(match.path) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path: %s", err) + } + + var client *lsp.Client + for c := range r.lspClients.Seq() { + if c.HandlesFile(absPath) { + client = c + break + } + } + + if client == nil { + slog.Warn("No LSP clients to handle", "path", match.path) + return nil, nil + } + + return client.FindReferences( + ctx, + absPath, + match.lineNum, + match.charNum+getSymbolOffset(symbol), + true, + ) +} + +// getSymbolOffset returns the character offset to the actual symbol name +// in a qualified symbol (e.g., "Bar" in "foo.Bar" or "method" in "Class::method"). +func getSymbolOffset(symbol string) int { + // Check for :: separator (Rust, C++, Ruby modules/classes, PHP static). + if idx := strings.LastIndex(symbol, "::"); idx != -1 { + return idx + 2 + } + // Check for . separator (Go, Python, JavaScript, Java, C#, Ruby methods). + if idx := strings.LastIndex(symbol, "."); idx != -1 { + return idx + 1 + } + // Check for \ separator (PHP namespaces). 
+ if idx := strings.LastIndex(symbol, "\\"); idx != -1 { + return idx + 1 + } + return 0 +} + +func cleanupLocations(locations []protocol.Location) []protocol.Location { + slices.SortFunc(locations, func(a, b protocol.Location) int { + if a.URI != b.URI { + return strings.Compare(string(a.URI), string(b.URI)) + } + if a.Range.Start.Line != b.Range.Start.Line { + return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line) + } + return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character) + }) + return slices.CompactFunc(locations, func(a, b protocol.Location) bool { + return a.URI == b.URI && + a.Range.Start.Line == b.Range.Start.Line && + a.Range.Start.Character == b.Range.Start.Character + }) +} + +func groupByFilename(locations []protocol.Location) map[string][]protocol.Location { + files := make(map[string][]protocol.Location) + for _, loc := range locations { + path, err := loc.URI.Path() + if err != nil { + slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err) + continue + } + files[path] = append(files[path], loc) + } + return files +} + +func formatReferences(locations []protocol.Location) string { + fileRefs := groupByFilename(locations) + files := slices.Collect(maps.Keys(fileRefs)) + sort.Strings(files) + + var output strings.Builder + output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files))) + + for _, file := range files { + refs := fileRefs[file] + output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs))) + for _, ref := range refs { + line := ref.Range.Start.Line + 1 + char := ref.Range.Start.Character + 1 + output.WriteString(fmt.Sprintf(" Line %d, Column %d\n", line, char)) + } + output.WriteString("\n") + } + + return output.String() +} diff --git a/internal/llm/tools/references.md b/internal/llm/tools/references.md new file mode 100644 index 0000000000000000000000000000000000000000..951ce71a68b9d62060649cda999107ab9243f42a --- /dev/null +++ 
b/internal/llm/tools/references.md @@ -0,0 +1,36 @@ +Find all references to/usage of a symbol by name using the Language Server Protocol (LSP). + +WHEN TO USE THIS TOOL: + +- **ALWAYS USE THIS FIRST** when searching for where a function, method, variable, type, or constant is used +- **DO NOT use grep/glob for symbol searches** - this tool is semantic-aware and much more accurate +- Use when you need to find all usages of a specific symbol (function, variable, type, class, method, etc.) +- More accurate than grep because it understands code semantics and scope +- Finds only actual references, not string matches in comments or unrelated code +- Helpful for understanding where a symbol is used throughout the codebase +- Useful for refactoring or analyzing code dependencies +- Good for finding all call sites of a function, method, type, package, constant, variable, etc. + +HOW TO USE: + +- Provide the symbol name (e.g., "MyFunction", "myVariable", "MyType") +- Optionally specify a path to narrow the search to a specific directory +- The tool will automatically find the symbol and locate all references + +FEATURES: + +- Returns all references grouped by file +- Shows line and column numbers for each reference +- Supports multiple programming languages through LSP +- Automatically finds the symbol without needing exact position + +LIMITATIONS: + +- May not find references in files that haven't been opened or indexed +- Results depend on the LSP server's capabilities + +TIPS: + +- **Use this tool instead of grep when looking for symbol references** - it's more accurate and semantic-aware +- Simply provide the symbol name and let the tool find it for you +- This tool understands code structure, so it won't match unrelated strings or comments diff --git a/internal/llm/tools/rg.go b/internal/llm/tools/rg.go index 8809b57c8db30b4ac1ed6c070df5a7218c59e233..76dbb5daf2234669ac3d90552cbbc5af5cc003d0 100644 --- a/internal/llm/tools/rg.go +++ b/internal/llm/tools/rg.go @@ -43,7 
+43,7 @@ func getRgSearchCmd(ctx context.Context, pattern, path, include string) *exec.Cm return nil } // Use -n to show line numbers, -0 for null separation to handle Windows paths - args := []string{"-H", "-n", "-0", pattern} + args := []string{"--json", "-H", "-n", "-0", pattern} if include != "" { args = append(args, "--glob", include) } diff --git a/internal/llm/tools/testdata/grep.txt b/internal/llm/tools/testdata/grep.txt new file mode 100644 index 0000000000000000000000000000000000000000..edac9ec894634e3b924fb9a0928a272ac4f29e7e --- /dev/null +++ b/internal/llm/tools/testdata/grep.txt @@ -0,0 +1,3 @@ +test file for grep +I wanna grep THIS particular word +and nothing else diff --git a/internal/lsp/client.go b/internal/lsp/client.go index ff9a3ac9b5249663c151fb2df04a4acb168e4de4..afbe95cc2deb1c37b64c9e9b68fb705a4a0a59f9 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -445,6 +445,16 @@ func (c *Client) WaitForDiagnostics(ctx context.Context, d time.Duration) { } } +// FindReferences finds all references to the symbol at the given position. +func (c *Client) FindReferences(ctx context.Context, filepath string, line, character int, includeDeclaration bool) ([]protocol.Location, error) { + if err := c.OpenFileOnDemand(ctx, filepath); err != nil { + return nil, err + } + // NOTE: line and character should be 0-based. + // See: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#position + return c.client.FindReferences(ctx, filepath, line-1, character-1, includeDeclaration) +} + // HasRootMarkers checks if any of the specified root marker patterns exist in the given directory. // Uses glob patterns to match files, allowing for more flexible matching. func HasRootMarkers(dir string, rootMarkers []string) bool {