grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"context"
  6	_ "embed"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"path/filepath"
 12	"regexp"
 13	"sort"
 14	"strconv"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/fsext"
 20	"github.com/charmbracelet/fantasy/ai"
 21)
 22
 23// regexCache provides thread-safe caching of compiled regex patterns
 24type regexCache struct {
 25	cache map[string]*regexp.Regexp
 26	mu    sync.RWMutex
 27}
 28
 29// newRegexCache creates a new regex cache
 30func newRegexCache() *regexCache {
 31	return &regexCache{
 32		cache: make(map[string]*regexp.Regexp),
 33	}
 34}
 35
 36// get retrieves a compiled regex from cache or compiles and caches it
 37func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 38	// Try to get from cache first (read lock)
 39	rc.mu.RLock()
 40	if regex, exists := rc.cache[pattern]; exists {
 41		rc.mu.RUnlock()
 42		return regex, nil
 43	}
 44	rc.mu.RUnlock()
 45
 46	// Compile the regex (write lock)
 47	rc.mu.Lock()
 48	defer rc.mu.Unlock()
 49
 50	// Double-check in case another goroutine compiled it while we waited
 51	if regex, exists := rc.cache[pattern]; exists {
 52		return regex, nil
 53	}
 54
 55	// Compile and cache the regex
 56	regex, err := regexp.Compile(pattern)
 57	if err != nil {
 58		return nil, err
 59	}
 60
 61	rc.cache[pattern] = regex
 62	return regex, nil
 63}
 64
 65// Global regex cache instances
 66var (
 67	searchRegexCache = newRegexCache()
 68	globRegexCache   = newRegexCache()
 69	// Pre-compiled regex for glob conversion (used frequently)
 70	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
 71)
 72
 73type GrepParams struct {
 74	Pattern     string `json:"pattern" description:"The regex pattern to search for in file contents"`
 75	Path        string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
 76	Include     string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
 77	LiteralText bool   `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
 78}
 79
 80type grepMatch struct {
 81	path     string
 82	modTime  time.Time
 83	lineNum  int
 84	lineText string
 85}
 86
 87type GrepResponseMetadata struct {
 88	NumberOfMatches int  `json:"number_of_matches"`
 89	Truncated       bool `json:"truncated"`
 90}
 91
 92const (
 93	GrepToolName        = "grep"
 94	maxGrepContentWidth = 500
 95)
 96
 97//go:embed grep.md
 98var grepDescription []byte
 99
100func NewGrepTool(workingDir string) ai.AgentTool {
101	return ai.NewAgentTool(
102		GrepToolName,
103		string(grepDescription),
104		func(ctx context.Context, params GrepParams, call ai.ToolCall) (ai.ToolResponse, error) {
105			if params.Pattern == "" {
106				return ai.NewTextErrorResponse("pattern is required"), nil
107			}
108
109			// If literal_text is true, escape the pattern
110			searchPattern := params.Pattern
111			if params.LiteralText {
112				searchPattern = escapeRegexPattern(params.Pattern)
113			}
114
115			searchPath := params.Path
116			if searchPath == "" {
117				searchPath = workingDir
118			}
119
120			matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
121			if err != nil {
122				return ai.ToolResponse{}, fmt.Errorf("error searching files: %w", err)
123			}
124
125			var output strings.Builder
126			if len(matches) == 0 {
127				output.WriteString("No files found")
128			} else {
129				fmt.Fprintf(&output, "Found %d matches\n", len(matches))
130
131				currentFile := ""
132				for _, match := range matches {
133					if currentFile != match.path {
134						if currentFile != "" {
135							output.WriteString("\n")
136						}
137						currentFile = match.path
138						fmt.Fprintf(&output, "%s:\n", match.path)
139					}
140					if match.lineNum > 0 {
141						lineText := match.lineText
142						if len(lineText) > maxGrepContentWidth {
143							lineText = lineText[:maxGrepContentWidth] + "..."
144						}
145						fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, lineText)
146					} else {
147						fmt.Fprintf(&output, "  %s\n", match.path)
148					}
149				}
150
151				if truncated {
152					output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
153				}
154			}
155
156			return ai.WithResponseMetadata(
157				ai.NewTextResponse(output.String()),
158				GrepResponseMetadata{
159					NumberOfMatches: len(matches),
160					Truncated:       truncated,
161				},
162			), nil
163		})
164}
165
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) | [ ] { } ^ $) backslash-escaped, so the result
// matches the original text literally.
func escapeRegexPattern(pattern string) string {
	// regexp.QuoteMeta escapes exactly this set in a single pass.
	return regexp.QuoteMeta(pattern)
}
177
178func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
179	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
180	if err != nil {
181		matches, err = searchFilesWithRegex(pattern, rootPath, include)
182		if err != nil {
183			return nil, false, err
184		}
185	}
186
187	sort.Slice(matches, func(i, j int) bool {
188		return matches[i].modTime.After(matches[j].modTime)
189	})
190
191	truncated := len(matches) > limit
192	if truncated {
193		matches = matches[:limit]
194	}
195
196	return matches, truncated, nil
197}
198
199func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
200	cmd := getRgSearchCmd(ctx, pattern, path, include)
201	if cmd == nil {
202		return nil, fmt.Errorf("ripgrep not found in $PATH")
203	}
204
205	// Only add ignore files if they exist
206	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
207		ignorePath := filepath.Join(path, ignoreFile)
208		if _, err := os.Stat(ignorePath); err == nil {
209			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
210		}
211	}
212
213	output, err := cmd.Output()
214	if err != nil {
215		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
216			return []grepMatch{}, nil
217		}
218		return nil, err
219	}
220
221	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
222	matches := make([]grepMatch, 0, len(lines))
223
224	for _, line := range lines {
225		if line == "" {
226			continue
227		}
228
229		// Parse ripgrep output using null separation
230		filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
231		if !ok {
232			continue
233		}
234
235		lineNum, err := strconv.Atoi(lineNumStr)
236		if err != nil {
237			continue
238		}
239
240		fileInfo, err := os.Stat(filePath)
241		if err != nil {
242			continue // Skip files we can't access
243		}
244
245		matches = append(matches, grepMatch{
246			path:     filePath,
247			modTime:  fileInfo.ModTime(),
248			lineNum:  lineNum,
249			lineText: lineText,
250		})
251	}
252
253	return matches, nil
254}
255
// parseRipgrepLine splits one line of ripgrep output of the form
// "<path>\x00<line number>:<text>" into its three parts.
//
// The path is delimited by a NUL byte rather than a colon, so paths that
// themselves contain colons (such as Windows drive letters) are handled.
// ok is false when the NUL or the colon is missing, or when the line-number
// field is not an integer; the string results are then empty.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	name, rest, found := strings.Cut(line, "\x00")
	if !found {
		return "", "", "", false
	}

	// rest is "linenum:content"; only the first colon is a separator.
	number, text, found := strings.Cut(rest, ":")
	if !found {
		return "", "", "", false
	}

	if _, convErr := strconv.Atoi(number); convErr != nil {
		return "", "", "", false
	}

	return name, number, text, true
}
282
283func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
284	matches := []grepMatch{}
285
286	// Use cached regex compilation
287	regex, err := searchRegexCache.get(pattern)
288	if err != nil {
289		return nil, fmt.Errorf("invalid regex pattern: %w", err)
290	}
291
292	var includePattern *regexp.Regexp
293	if include != "" {
294		regexPattern := globToRegex(include)
295		includePattern, err = globRegexCache.get(regexPattern)
296		if err != nil {
297			return nil, fmt.Errorf("invalid include pattern: %w", err)
298		}
299	}
300
301	// Create walker with gitignore and crushignore support
302	walker := fsext.NewFastGlobWalker(rootPath)
303
304	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
305		if err != nil {
306			return nil // Skip errors
307		}
308
309		if info.IsDir() {
310			// Check if directory should be skipped
311			if walker.ShouldSkip(path) {
312				return filepath.SkipDir
313			}
314			return nil // Continue into directory
315		}
316
317		// Use walker's shouldSkip method for files
318		if walker.ShouldSkip(path) {
319			return nil
320		}
321
322		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
323		base := filepath.Base(path)
324		if base != "." && strings.HasPrefix(base, ".") {
325			return nil
326		}
327
328		if includePattern != nil && !includePattern.MatchString(path) {
329			return nil
330		}
331
332		match, lineNum, lineText, err := fileContainsPattern(path, regex)
333		if err != nil {
334			return nil // Skip files we can't read
335		}
336
337		if match {
338			matches = append(matches, grepMatch{
339				path:     path,
340				modTime:  info.ModTime(),
341				lineNum:  lineNum,
342				lineText: lineText,
343			})
344
345			if len(matches) >= 200 {
346				return filepath.SkipAll
347			}
348		}
349
350		return nil
351	})
352	if err != nil {
353		return nil, err
354	}
355
356	return matches, nil
357}
358
359func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
360	// Quick binary file detection
361	if isBinaryFile(filePath) {
362		return false, 0, "", nil
363	}
364
365	file, err := os.Open(filePath)
366	if err != nil {
367		return false, 0, "", err
368	}
369	defer file.Close()
370
371	scanner := bufio.NewScanner(file)
372	lineNum := 0
373	for scanner.Scan() {
374		lineNum++
375		line := scanner.Text()
376		if pattern.MatchString(line) {
377			return true, lineNum, line, nil
378		}
379	}
380
381	return false, 0, "", scanner.Err()
382}
383
384var binaryExts = map[string]struct{}{
385	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
386	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
387	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
388	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
389	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
390	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
391}
392
393// isBinaryFile performs a quick check to determine if a file is binary
394func isBinaryFile(filePath string) bool {
395	// Check file extension first (fastest)
396	ext := strings.ToLower(filepath.Ext(filePath))
397	if _, isBinary := binaryExts[ext]; isBinary {
398		return true
399	}
400
401	// Quick content check for files without clear extensions
402	file, err := os.Open(filePath)
403	if err != nil {
404		return false // If we can't open it, let the caller handle the error
405	}
406	defer file.Close()
407
408	// Read first 512 bytes to check for null bytes
409	buffer := make([]byte, 512)
410	n, err := file.Read(buffer)
411	if err != nil && err != io.EOF {
412		return false
413	}
414
415	// Check for null bytes (common in binary files)
416	for i := range n {
417		if buffer[i] == 0 {
418			return true
419		}
420	}
421
422	return false
423}
424
// globToRegex translates a simple glob into regular-expression source.
//
// Supported glob syntax:
//   - "*" becomes ".*" and "?" becomes "."
//   - a non-empty brace group "{a,b}" becomes the alternation "(a|b)"
//   - "[" and "]" are passed through unchanged, so a glob character class
//     is also a regex character class
//
// Every other regex metacharacter (. + ( ) ^ $ | \), as well as braces that
// do not form a group, is escaped so that it matches itself. The result is
// not anchored; callers add anchors as needed.
func globToRegex(glob string) string {
	var out strings.Builder
	out.Grow(2 * len(glob))

	// All characters of interest are ASCII, so walking bytes is safe: the
	// bytes of multi-byte UTF-8 runes are copied through untouched.
	inBraces := false
	for i := 0; i < len(glob); i++ {
		c := glob[i]
		switch c {
		case '*':
			out.WriteString(".*")
		case '?':
			out.WriteByte('.')
		case '{':
			// A group needs a closing brace with at least one character in
			// between; anything else is a literal brace.
			if !inBraces && strings.IndexByte(glob[i+1:], '}') > 0 {
				inBraces = true
				out.WriteByte('(')
			} else {
				out.WriteString(`\{`)
			}
		case '}':
			if inBraces {
				inBraces = false
				out.WriteByte(')')
			} else {
				out.WriteString(`\}`)
			}
		case ',':
			// Commas separate alternatives only inside a brace group.
			if inBraces {
				out.WriteByte('|')
			} else {
				out.WriteByte(',')
			}
		case '.', '+', '(', ')', '^', '$', '|', '\\':
			out.WriteByte('\\')
			out.WriteByte(c)
		default:
			out.WriteByte(c)
		}
	}

	return out.String()
}