test(grep): ensure lines over len are trunced as well as output amount

tauraamui created

Change summary

internal/llm/tools/grep.go      |  8 ++-
internal/llm/tools/grep_test.go | 61 +++++++++++++++++++++++++++++++++++
2 files changed, 66 insertions(+), 3 deletions(-)

Detailed changes

internal/llm/tools/grep.go 🔗

@@ -170,7 +170,7 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 		searchPath = g.workingDir
 	}
 
-	matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
+	matches, truncated, err := searchFiles(ctx, searchWithRipgrep, searchPattern, searchPath, params.Include, 100)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
 	}
@@ -215,8 +215,8 @@ func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 	), nil
 }
 
-func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
-	matches, err := searchWithRipgrep(ctx, getRgSearchCmd, pattern, rootPath, include)
+func searchFiles(ctx context.Context, ripGrepSearch searchWithRipgrapFn, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
+	matches, err := ripGrepSearch(ctx, getRgSearchCmd, pattern, rootPath, include)
 	if err != nil {
 		matches, err = searchFilesWithRegex(pattern, rootPath, include)
 		if err != nil {
@@ -236,6 +236,8 @@ func searchFiles(ctx context.Context, pattern, rootPath, include string, limit i
 	return matches, truncated, nil
 }
 
+type searchWithRipgrapFn func(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error)
+
 // NOTE(tauraamui): ideally I would want to not pass in the search specific args here but will leave for now
 func searchWithRipgrep(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) {
 	cmd := rgSearchCmd(ctx, pattern, path, include)

internal/llm/tools/grep_test.go 🔗

@@ -2,11 +2,13 @@ package tools
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"os/exec"
 	"path/filepath"
 	"regexp"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 )
@@ -247,6 +249,65 @@ func TestSearchWithRipGrepButItFailsToRunHandleError(t *testing.T) {
 	}
 }
 
+func TestSearchFilesWithLimit(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name              string
+		limit             int
+		numMatches        int
+		expectedMatches   int
+		expectedTruncated bool
+	}{
+		{
+			name:              "limit of 100 truncates 150 results",
+			limit:             100,
+			numMatches:        150,
+			expectedMatches:   100,
+			expectedTruncated: true,
+		},
+		{
+			name:              "limit of 200 does not truncate 150 results",
+			limit:             200,
+			numMatches:        150,
+			expectedMatches:   150,
+			expectedTruncated: false,
+		},
+		{
+			name:              "limit of 150 exactly matches all files",
+			limit:             150,
+			numMatches:        150,
+			expectedMatches:   150,
+			expectedTruncated: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			// Create mock ripgrep search that returns fake matches.
+			mockRipgrepSearch := func(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) {
+				matches := make([]grepMatch, tt.numMatches)
+				for i := 0; i < tt.numMatches; i++ {
+					matches[i] = grepMatch{
+						path:     fmt.Sprintf("/fake/path/file%03d.txt", i),
+						modTime:  time.Now().Add(-time.Duration(i) * time.Minute),
+						lineNum:  1,
+						lineText: "test pattern",
+					}
+				}
+				return matches, nil
+			}
+
+			matches, truncated, err := searchFiles(t.Context(), mockRipgrepSearch, "test pattern", "/fake/path", "", tt.limit)
+			require.NoError(t, err)
+			require.Equal(t, tt.expectedMatches, len(matches))
+			require.Equal(t, tt.expectedTruncated, truncated)
+		})
+	}
+}
+
 // Benchmark to show performance improvement
 func BenchmarkRegexCacheVsCompile(b *testing.B) {
 	cache := newRegexCache()