test(grep): help cover failure case for invoking rg

tauraamui created

Change summary

internal/llm/tools/grep.go      |  9 +++---
internal/llm/tools/grep_test.go | 45 +++++++++++++++++++++++++++++++++-
internal/llm/tools/rg.go        | 25 +++++++++++++++++-
3 files changed, 71 insertions(+), 8 deletions(-)

Detailed changes

internal/llm/tools/grep.go 🔗

@@ -216,7 +216,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 }
 
 func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
-	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
+	matches, err := searchWithRipgrep(ctx, getRgSearchCmd, pattern, rootPath, include)
 	if err != nil {
 		matches, err = searchFilesWithRegex(pattern, rootPath, include)
 		if err != nil {
@@ -236,8 +236,9 @@ func searchFiles(ctx context.Context, pattern, rootPath, include string, limit i
 	return matches, truncated, nil
 }
 
-func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
-	cmd := getRgSearchCmd(ctx, pattern, path, include)
+// NOTE(tauraamui): ideally I would want to not pass in the search specific args here but will leave for now
+func searchWithRipgrep(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) {
+	cmd := rgSearchCmd(ctx, pattern, path, include)
 	if cmd == nil {
 		return nil, fmt.Errorf("ripgrep not found in $PATH")
 	}
@@ -246,7 +247,7 @@ func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]gr
 	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
 		ignorePath := filepath.Join(path, ignoreFile)
 		if _, err := os.Stat(ignorePath); err == nil {
-			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
+			cmd.AddArgs("--ignore-file", ignorePath)
 		}
 	}
 

internal/llm/tools/grep_test.go 🔗

@@ -1,7 +1,9 @@
 package tools
 
 import (
+	"context"
 	"os"
+	"os/exec"
 	"path/filepath"
 	"regexp"
 	"testing"
@@ -87,7 +89,7 @@ func TestGrepWithIgnoreFiles(t *testing.T) {
 	for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){
 		"regex": searchFilesWithRegex,
 		"rg": func(pattern, path, include string) ([]grepMatch, error) {
-			return searchWithRipgrep(t.Context(), pattern, path, include)
+			return searchWithRipgrep(t.Context(), getRgSearchCmd, pattern, path, include)
 		},
 	} {
 		t.Run(name, func(t *testing.T) {
@@ -147,7 +149,7 @@ func TestSearchImplementations(t *testing.T) {
 	for name, fn := range map[string]func(pattern, path, include string) ([]grepMatch, error){
 		"regex": searchFilesWithRegex,
 		"rg": func(pattern, path, include string) ([]grepMatch, error) {
-			return searchWithRipgrep(t.Context(), pattern, path, include)
+			return searchWithRipgrep(t.Context(), getRgSearchCmd, pattern, path, include)
 		},
 	} {
 		t.Run(name, func(t *testing.T) {
@@ -175,6 +177,45 @@ func TestSearchImplementations(t *testing.T) {
 	}
 }
 
+type mockRgExecCmd struct {
+	args []string
+	err  error
+}
+
+func (m *mockRgExecCmd) AddArgs(args ...string) {
+	m.args = append(m.args, args...)
+}
+
+func (m *mockRgExecCmd) Output() ([]byte, error) {
+	if m.err != nil {
+		return nil, m.err
+	}
+	return []byte{}, nil
+}
+
+func TestSearchWithRipGrepButItFailsToRunHandleError(t *testing.T) {
+	// create separate proc state that exits with code 1
+	cmd := exec.Command("sh", "-c", "exit 1")
+	err := cmd.Run()
+	require.Error(t, err)
+
+	// Extract the ExitError with real ProcessState.
+	exitErr, ok := err.(*exec.ExitError)
+	require.True(t, ok)
+	require.Equal(t, 1, exitErr.ExitCode())
+
+	mockRgCmd := mockRgExecCmd{
+		err: exitErr,
+	}
+
+	matches, err := searchWithRipgrep(t.Context(), func(ctx context.Context, pattern, path, include string) execCmd {
+		return &mockRgCmd
+	}, "", "", "")
+
+	require.Empty(t, matches)
+	require.NoError(t, err)
+}
+
 // Benchmark to show performance improvement
 func BenchmarkRegexCacheVsCompile(b *testing.B) {
 	cache := newRegexCache()

internal/llm/tools/rg.go 🔗

@@ -37,7 +37,26 @@ func getRgCmd(ctx context.Context, globPattern string) *exec.Cmd {
 	return exec.CommandContext(ctx, name, args...)
 }
 
-func getRgSearchCmd(ctx context.Context, pattern, path, include string) *exec.Cmd {
+type execCmd interface {
+	AddArgs(arg ...string)
+	Output() ([]byte, error)
+}
+
+type wrappedCmd struct {
+	cmd *exec.Cmd
+}
+
+func (w wrappedCmd) AddArgs(arg ...string) {
+	w.cmd.Args = append(w.cmd.Args, arg...)
+}
+
+func (w wrappedCmd) Output() ([]byte, error) {
+	return w.cmd.Output()
+}
+
+type resolveRgSearchCmd func(ctx context.Context, pattern, path, include string) execCmd
+
+func getRgSearchCmd(ctx context.Context, pattern, path, include string) execCmd {
 	name := getRg()
 	if name == "" {
 		return nil
@@ -49,5 +68,7 @@ func getRgSearchCmd(ctx context.Context, pattern, path, include string) *exec.Cm
 	}
 	args = append(args, path)
 
-	return exec.CommandContext(ctx, name, args...)
+	return wrappedCmd{
+		cmd: exec.CommandContext(ctx, name, args...),
+	}
 }