diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 3af4781a8a8277c7a5082ead0fd0b561ab70cce1..7ac059edc87559a39277dcf2710ce1cb08b6058a 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -216,7 +216,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) } func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) { - matches, err := searchWithRipgrep(ctx, pattern, rootPath, include) + matches, err := searchWithRipgrep(ctx, getRgSearchCmd, pattern, rootPath, include) if err != nil { matches, err = searchFilesWithRegex(pattern, rootPath, include) if err != nil { @@ -236,8 +236,9 @@ func searchFiles(ctx context.Context, pattern, rootPath, include string, limit i return matches, truncated, nil } -func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) { - cmd := getRgSearchCmd(ctx, pattern, path, include) +// NOTE(tauraamui): ideally I would want to not pass in the search specific args here but will leave for now +func searchWithRipgrep(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) { + cmd := rgSearchCmd(ctx, pattern, path, include) if cmd == nil { return nil, fmt.Errorf("ripgrep not found in $PATH") } @@ -246,7 +247,7 @@ func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]gr for _, ignoreFile := range []string{".gitignore", ".crushignore"} { ignorePath := filepath.Join(path, ignoreFile) if _, err := os.Stat(ignorePath); err == nil { - cmd.Args = append(cmd.Args, "--ignore-file", ignorePath) + cmd.AddArgs("--ignore-file", ignorePath) } } diff --git a/internal/llm/tools/grep_test.go b/internal/llm/tools/grep_test.go index 53c96b22df444adfba59c6b13995a104411a57be..a23563583661148d0026cd4f09667dfbe92a9e78 100644 --- a/internal/llm/tools/grep_test.go +++ b/internal/llm/tools/grep_test.go @@ -1,7 +1,9 @@ package tools import ( + "context" "os" + "os/exec" "path/filepath" "regexp" "testing" @@ -87,7 +89,7 @@ func TestGrepWithIgnoreFiles(t *testing.T) { for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){ "regex": searchFilesWithRegex, "rg": func(pattern, path, include string) ([]grepMatch, error) { - return searchWithRipgrep(t.Context(), pattern, path, include) + return searchWithRipgrep(t.Context(), getRgSearchCmd, pattern, path, include) }, } { t.Run(name, func(t *testing.T) { @@ -147,7 +149,7 @@ func TestSearchImplementations(t *testing.T) { for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){ "regex": searchFilesWithRegex, "rg": func(pattern, path, include string) ([]grepMatch, error) { - return searchWithRipgrep(t.Context(), pattern, path, include) + return searchWithRipgrep(t.Context(), getRgSearchCmd, pattern, path, include) }, } { t.Run(name, func(t *testing.T) { @@ -175,6 +177,45 @@ func TestSearchImplementations(t *testing.T) { } } +type mockRgExecCmd struct { + args []string + err error +} + +func (m *mockRgExecCmd) AddArgs(args ...string) { + m.args = append(m.args, args...) +} + +func (m *mockRgExecCmd) Output() ([]byte, error) { + if m.err != nil { + return nil, m.err + } + return []byte{}, nil +} + +func TestSearchWithRipGrepButItFailsToRunHandleError(t *testing.T) { + // create separate proc state that exits with code 1 + cmd := exec.Command("sh", "-c", "exit 1") + err := cmd.Run() + require.Error(t, err) + + // Extract the ExitError with real ProcessState. + exitErr, ok := err.(*exec.ExitError) + require.True(t, ok) + require.Equal(t, 1, exitErr.ExitCode()) + + mockRgCmd := mockRgExecCmd{ + err: exitErr, + } + + matches, err := searchWithRipgrep(t.Context(), func(ctx context.Context, pattern, path, include string) execCmd { + return &mockRgCmd + }, "", "", "") + + require.Empty(t, matches) + require.NoError(t, err) +} + // Benchmark to show performance improvement func BenchmarkRegexCacheVsCompile(b *testing.B) { cache := newRegexCache() diff --git a/internal/llm/tools/rg.go b/internal/llm/tools/rg.go index 8809b57c8db30b4ac1ed6c070df5a7218c59e233..db90f738b78d75758f2f8fe6faec307380b55fe6 100644 --- a/internal/llm/tools/rg.go +++ b/internal/llm/tools/rg.go @@ -37,7 +37,26 @@ func getRgCmd(ctx context.Context, globPattern string) *exec.Cmd { return exec.CommandContext(ctx, name, args...) } -func getRgSearchCmd(ctx context.Context, pattern, path, include string) *exec.Cmd { +type execCmd interface { + AddArgs(arg ...string) + Output() ([]byte, error) +} + +type wrappedCmd struct { + cmd *exec.Cmd +} + +func (w wrappedCmd) AddArgs(arg ...string) { + w.cmd.Args = append(w.cmd.Args, arg...) +} + +func (w wrappedCmd) Output() ([]byte, error) { + return w.cmd.Output() +} + +type resolveRgSearchCmd func(ctx context.Context, pattern, path, include string) execCmd + +func getRgSearchCmd(ctx context.Context, pattern, path, include string) execCmd { name := getRg() if name == "" { return nil @@ -49,5 +68,7 @@ func getRgSearchCmd(ctx context.Context, pattern, path, include string) *exec.Cm } args = append(args, path) - return exec.CommandContext(ctx, name, args...) + return wrappedCmd{ + cmd: exec.CommandContext(ctx, name, args...), + } }