diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index 7ac059edc87559a39277dcf2710ce1cb08b6058a..99a8c1c570d89973264447d0bf071eb10881fda3 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -170,7 +170,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) searchPath = g.workingDir } - matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100) + matches, truncated, err := searchFiles(ctx, searchWithRipgrep, searchPattern, searchPath, params.Include, 100) if err != nil { return ToolResponse{}, fmt.Errorf("error searching files: %w", err) } @@ -215,8 +215,8 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) ), nil } -func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) { - matches, err := searchWithRipgrep(ctx, getRgSearchCmd, pattern, rootPath, include) +func searchFiles(ctx context.Context, ripGrepSearch searchWithRipgrapFn, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) { + matches, err := ripGrepSearch(ctx, getRgSearchCmd, pattern, rootPath, include) if err != nil { matches, err = searchFilesWithRegex(pattern, rootPath, include) if err != nil { @@ -236,6 +236,8 @@ func searchFiles(ctx context.Context, pattern, rootPath, include string, limit i return matches, truncated, nil } +type searchWithRipgrapFn func(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) + // NOTE(tauraamui): ideally I would want to not pass in the search specific args here but will leave for now func searchWithRipgrep(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) { cmd := rgSearchCmd(ctx, pattern, path, include) diff --git a/internal/llm/tools/grep_test.go b/internal/llm/tools/grep_test.go index d258d93d7739b8b3c155ecc8caeb9c3921133edf..b858218d9cd23af6b682046d5664cb88bd33c0ec 100644 --- a/internal/llm/tools/grep_test.go +++ b/internal/llm/tools/grep_test.go @@ -2,11 +2,13 @@ package tools import ( "context" + "fmt" "os" "os/exec" "path/filepath" "regexp" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -247,6 +249,65 @@ func TestSearchWithRipGrepButItFailsToRunHandleError(t *testing.T) { } } +func TestSearchFilesWithLimit(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + limit int + numMatches int + expectedMatches int + expectedTruncated bool + }{ + { + name: "limit of 100 truncates 150 results", + limit: 100, + numMatches: 150, + expectedMatches: 100, + expectedTruncated: true, + }, + { + name: "limit of 200 does not truncate 150 results", + limit: 200, + numMatches: 150, + expectedMatches: 150, + expectedTruncated: false, + }, + { + name: "limit of 150 exactly matches all files", + limit: 150, + numMatches: 150, + expectedMatches: 150, + expectedTruncated: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create mock ripgrep search that returns fake matches. + mockRipgrepSearch := func(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) { + matches := make([]grepMatch, tt.numMatches) + for i := 0; i < tt.numMatches; i++ { + matches[i] = grepMatch{ + path: fmt.Sprintf("/fake/path/file%03d.txt", i), + modTime: time.Now().Add(-time.Duration(i) * time.Minute), + lineNum: 1, + lineText: "test pattern", + } + } + return matches, nil + } + + matches, truncated, err := searchFiles(t.Context(), mockRipgrepSearch, "test pattern", "/fake/path", "", tt.limit) + require.NoError(t, err) + require.Equal(t, tt.expectedMatches, len(matches)) + require.Equal(t, tt.expectedTruncated, truncated) + }) + } +} + // Benchmark to show performance improvement func BenchmarkRegexCacheVsCompile(b *testing.B) { cache := newRegexCache()