grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"os"
 11	"os/exec"
 12	"path/filepath"
 13	"regexp"
 14	"sort"
 15	"strconv"
 16	"strings"
 17	"sync"
 18	"time"
 19
 20	"github.com/charmbracelet/crush/internal/fsext"
 21)
 22
 23// regexCache provides thread-safe caching of compiled regex patterns
 24type regexCache struct {
 25	cache map[string]*regexp.Regexp
 26	mu    sync.RWMutex
 27}
 28
 29// newRegexCache creates a new regex cache
 30func newRegexCache() *regexCache {
 31	return &regexCache{
 32		cache: make(map[string]*regexp.Regexp),
 33	}
 34}
 35
 36// get retrieves a compiled regex from cache or compiles and caches it
 37func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 38	// Try to get from cache first (read lock)
 39	rc.mu.RLock()
 40	if regex, exists := rc.cache[pattern]; exists {
 41		rc.mu.RUnlock()
 42		return regex, nil
 43	}
 44	rc.mu.RUnlock()
 45
 46	// Compile the regex (write lock)
 47	rc.mu.Lock()
 48	defer rc.mu.Unlock()
 49
 50	// Double-check in case another goroutine compiled it while we waited
 51	if regex, exists := rc.cache[pattern]; exists {
 52		return regex, nil
 53	}
 54
 55	// Compile and cache the regex
 56	regex, err := regexp.Compile(pattern)
 57	if err != nil {
 58		return nil, err
 59	}
 60
 61	rc.cache[pattern] = regex
 62	return regex, nil
 63}
 64
// Package-level regex caches and pre-compiled patterns shared by every search.
var (
	// searchRegexCache holds the compiled content-search patterns and
	// globRegexCache holds the regexes derived from include globs; both are
	// consulted by searchFilesWithRegex.
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// globBraceRegex matches one "{...}" group with non-empty content. It is
	// compiled once here because globToRegex applies it on every call.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
 72
// GrepParams is the JSON input accepted by the grep tool.
type GrepParams struct {
	// Pattern is the regular expression to search for. Run rejects an empty value.
	Pattern     string `json:"pattern"`
	// Path is the directory to search. When empty, Run uses the tool's working directory.
	Path        string `json:"path"`
	// Include is an optional file glob (for example "*.js") limiting which files are searched.
	Include     string `json:"include"`
	// LiteralText makes Run escape Pattern so it is matched as plain text rather than as a regex.
	LiteralText bool   `json:"literal_text"`
}
 79
// grepMatch is a single search hit.
type grepMatch struct {
	// path is the file that contains the match.
	path     string
	// modTime is the file's modification time; searchFiles orders results by it, newest first.
	modTime  time.Time
	// lineNum is the line number of the match. Run prints only the path when it is not positive.
	lineNum  int
	// lineText is the content of the matching line.
	lineText string
}
 86
// GrepResponseMetadata is the structured metadata Run attaches to its response.
type GrepResponseMetadata struct {
	// NumberOfMatches is the number of matches included in the response (after truncation).
	NumberOfMatches int  `json:"number_of_matches"`
	// Truncated reports whether more matches were found than were returned.
	Truncated       bool `json:"truncated"`
}
 91
// grepTool implements the grep tool.
type grepTool struct {
	// workingDir is the directory searched when a call does not supply a path.
	workingDir string
}
 95
const (
	// GrepToolName is the name under which the grep tool is registered and reported.
	GrepToolName        = "grep"
	// maxGrepContentWidth is the most bytes of a matching line that Run prints
	// before it appends "..." in place of the remainder.
	maxGrepContentWidth = 500
)
100
// grepDescription is the tool description returned by Info. Its contents are
// embedded from grep.md at build time.
//
//go:embed grep.md
var grepDescription []byte
103
104func NewGrepTool(workingDir string) BaseTool {
105	return &grepTool{
106		workingDir: workingDir,
107	}
108}
109
// Name returns the tool's registered name, GrepToolName.
func (g *grepTool) Name() string {
	return GrepToolName
}
113
114func (g *grepTool) Info() ToolInfo {
115	return ToolInfo{
116		Name:        GrepToolName,
117		Description: string(grepDescription),
118		Parameters: map[string]any{
119			"pattern": map[string]any{
120				"type":        "string",
121				"description": "The regex pattern to search for in file contents",
122			},
123			"path": map[string]any{
124				"type":        "string",
125				"description": "The directory to search in. Defaults to the current working directory.",
126			},
127			"include": map[string]any{
128				"type":        "string",
129				"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
130			},
131			"literal_text": map[string]any{
132				"type":        "boolean",
133				"description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
134			},
135		},
136		Required: []string{"pattern"},
137	}
138}
139
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter backslash-escaped, so that the result matches pattern as
// literal text.
//
// regexp.QuoteMeta escapes exactly the characters \ . + * ? ( ) | [ ] { } ^ $
// in a single pass, which is the same set this function previously escaped
// with one strings.ReplaceAll call per character.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
151
152func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
153	var params GrepParams
154	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
155		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
156	}
157
158	if params.Pattern == "" {
159		return NewTextErrorResponse("pattern is required"), nil
160	}
161
162	// If literal_text is true, escape the pattern
163	searchPattern := params.Pattern
164	if params.LiteralText {
165		searchPattern = escapeRegexPattern(params.Pattern)
166	}
167
168	searchPath := params.Path
169	if searchPath == "" {
170		searchPath = g.workingDir
171	}
172
173	matches, truncated, err := searchFiles(ctx, searchWithRipgrep, searchPattern, searchPath, params.Include, 100)
174	if err != nil {
175		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
176	}
177
178	var output strings.Builder
179	if len(matches) == 0 {
180		output.WriteString("No files found")
181	} else {
182		fmt.Fprintf(&output, "Found %d matches\n", len(matches))
183
184		currentFile := ""
185		for _, match := range matches {
186			if currentFile != match.path {
187				if currentFile != "" {
188					output.WriteString("\n")
189				}
190				currentFile = match.path
191				fmt.Fprintf(&output, "%s:\n", match.path)
192			}
193			if match.lineNum > 0 {
194				lineText := match.lineText
195				if len(lineText) > maxGrepContentWidth {
196					lineText = lineText[:maxGrepContentWidth] + "..."
197				}
198				fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, lineText)
199			} else {
200				fmt.Fprintf(&output, "  %s\n", match.path)
201			}
202		}
203
204		if truncated {
205			output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
206		}
207	}
208
209	return WithResponseMetadata(
210		NewTextResponse(output.String()),
211		GrepResponseMetadata{
212			NumberOfMatches: len(matches),
213			Truncated:       truncated,
214		},
215	), nil
216}
217
218func searchFiles(ctx context.Context, ripGrepSearch searchWithRipgrapFn, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
219	matches, err := ripGrepSearch(ctx, getRgSearchCmd, pattern, rootPath, include)
220	if err != nil {
221		matches, err = searchFilesWithRegex(pattern, rootPath, include)
222		if err != nil {
223			return nil, false, err
224		}
225	}
226
227	sort.Slice(matches, func(i, j int) bool {
228		return matches[i].modTime.After(matches[j].modTime)
229	})
230
231	truncated := len(matches) > limit
232	if truncated {
233		matches = matches[:limit]
234	}
235
236	return matches, truncated, nil
237}
238
// searchWithRipgrapFn is the signature of a ripgrep-backed search. searchFiles
// takes the search as a value of this type and invokes it with getRgSearchCmd,
// the pattern, the root path and the include glob.
type searchWithRipgrapFn func(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error)
240
241// NOTE(tauraamui): ideally I would want to not pass in the search specific args here but will leave for now
242func searchWithRipgrep(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) {
243	cmd := rgSearchCmd(ctx, pattern, path, include)
244	if cmd == nil {
245		return nil, fmt.Errorf("ripgrep not found in $PATH")
246	}
247
248	// Only add ignore files if they exist
249	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
250		ignorePath := filepath.Join(path, ignoreFile)
251		if _, err := os.Stat(ignorePath); err == nil {
252			cmd.AddArgs("--ignore-file", ignorePath)
253		}
254	}
255
256	output, err := cmd.Output()
257	if err != nil {
258		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
259			return []grepMatch{}, nil
260		}
261		return nil, err
262	}
263
264	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
265	matches := make([]grepMatch, 0, len(lines))
266
267	for _, line := range lines {
268		if line == "" {
269			continue
270		}
271
272		// Parse ripgrep output using null separation
273		filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
274		if !ok {
275			continue
276		}
277
278		lineNum, err := strconv.Atoi(lineNumStr)
279		if err != nil {
280			continue
281		}
282
283		fileInfo, err := os.Stat(filePath)
284		if err != nil {
285			continue // Skip files we can't access
286		}
287
288		matches = append(matches, grepMatch{
289			path:     filePath,
290			modTime:  fileInfo.ModTime(),
291			lineNum:  lineNum,
292			lineText: lineText,
293		})
294	}
295
296	return matches, nil
297}
298
// parseRipgrepLine splits one ripgrep output line of the form
// "<path>\x00<line number>:<text>" into its three parts. The path ends at the
// first NUL byte rather than at a colon, so paths that themselves contain
// colons (such as Windows drive letters) are handled. ok is false, and all
// strings are empty, when the NUL or the colon is missing or when the line
// number is not an integer.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	nul := strings.Index(line, "\x00")
	if nul < 0 {
		return "", "", "", false
	}

	// Everything after the NUL is "linenum:content"; split at the first colon
	// only, because the content may contain colons of its own.
	fields := strings.SplitN(line[nul+1:], ":", 2)
	if len(fields) < 2 {
		return "", "", "", false
	}

	if _, err := strconv.Atoi(fields[0]); err != nil {
		return "", "", "", false
	}

	return line[:nul], fields[0], fields[1], true
}
325
326func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
327	matches := []grepMatch{}
328
329	// Use cached regex compilation
330	regex, err := searchRegexCache.get(pattern)
331	if err != nil {
332		return nil, fmt.Errorf("invalid regex pattern: %w", err)
333	}
334
335	var includePattern *regexp.Regexp
336	if include != "" {
337		regexPattern := globToRegex(include)
338		includePattern, err = globRegexCache.get(regexPattern)
339		if err != nil {
340			return nil, fmt.Errorf("invalid include pattern: %w", err)
341		}
342	}
343
344	// Create walker with gitignore and crushignore support
345	walker := fsext.NewFastGlobWalker(rootPath)
346
347	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
348		if err != nil {
349			return nil // Skip errors
350		}
351
352		if info.IsDir() {
353			// Check if directory should be skipped
354			if walker.ShouldSkip(path) {
355				return filepath.SkipDir
356			}
357			return nil // Continue into directory
358		}
359
360		// Use walker's shouldSkip method for files
361		if walker.ShouldSkip(path) {
362			return nil
363		}
364
365		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
366		base := filepath.Base(path)
367		if base != "." && strings.HasPrefix(base, ".") {
368			return nil
369		}
370
371		if includePattern != nil && !includePattern.MatchString(path) {
372			return nil
373		}
374
375		match, lineNum, lineText, err := fileContainsPattern(path, regex)
376		if err != nil {
377			return nil // Skip files we can't read
378		}
379
380		if match {
381			matches = append(matches, grepMatch{
382				path:     path,
383				modTime:  info.ModTime(),
384				lineNum:  lineNum,
385				lineText: lineText,
386			})
387
388			if len(matches) >= 200 {
389				return filepath.SkipAll
390			}
391		}
392
393		return nil
394	})
395	if err != nil {
396		return nil, err
397	}
398
399	return matches, nil
400}
401
402func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
403	// Quick binary file detection
404	if isBinaryFile(filePath) {
405		return false, 0, "", nil
406	}
407
408	file, err := os.Open(filePath)
409	if err != nil {
410		return false, 0, "", err
411	}
412	defer file.Close()
413
414	scanner := bufio.NewScanner(file)
415	lineNum := 0
416	for scanner.Scan() {
417		lineNum++
418		line := scanner.Text()
419		if pattern.MatchString(line) {
420			return true, lineNum, line, nil
421		}
422	}
423
424	return false, 0, "", scanner.Err()
425}
426
427var binaryExts = map[string]struct{}{
428	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
429	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
430	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
431	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
432	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
433	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
434}
435
436// isBinaryFile performs a quick check to determine if a file is binary
437func isBinaryFile(filePath string) bool {
438	// Check file extension first (fastest)
439	ext := strings.ToLower(filepath.Ext(filePath))
440	if _, isBinary := binaryExts[ext]; isBinary {
441		return true
442	}
443
444	// Quick content check for files without clear extensions
445	file, err := os.Open(filePath)
446	if err != nil {
447		return false // If we can't open it, let the caller handle the error
448	}
449	defer file.Close()
450
451	// Read first 512 bytes to check for null bytes
452	buffer := make([]byte, 512)
453	n, err := file.Read(buffer)
454	if err != nil && err != io.EOF {
455		return false
456	}
457
458	// Check for null bytes (common in binary files)
459	for i := range n {
460		if buffer[i] == 0 {
461			return true
462		}
463	}
464
465	return false
466}
467
468func globToRegex(glob string) string {
469	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
470	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
471	regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
472
473	// Use pre-compiled regex instead of compiling each time
474	regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
475		inner := match[1 : len(match)-1]
476		return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
477	})
478
479	return regexPattern
480}