grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"os"
 11	"os/exec"
 12	"path/filepath"
 13	"regexp"
 14	"sort"
 15	"strconv"
 16	"strings"
 17	"sync"
 18	"time"
 19
 20	"github.com/charmbracelet/crush/internal/fsext"
 21)
 22
 23// regexCache provides thread-safe caching of compiled regex patterns
 24type regexCache struct {
 25	cache map[string]*regexp.Regexp
 26	mu    sync.RWMutex
 27}
 28
 29// newRegexCache creates a new regex cache
 30func newRegexCache() *regexCache {
 31	return &regexCache{
 32		cache: make(map[string]*regexp.Regexp),
 33	}
 34}
 35
 36// get retrieves a compiled regex from cache or compiles and caches it
 37func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 38	// Try to get from cache first (read lock)
 39	rc.mu.RLock()
 40	if regex, exists := rc.cache[pattern]; exists {
 41		rc.mu.RUnlock()
 42		return regex, nil
 43	}
 44	rc.mu.RUnlock()
 45
 46	// Compile the regex (write lock)
 47	rc.mu.Lock()
 48	defer rc.mu.Unlock()
 49
 50	// Double-check in case another goroutine compiled it while we waited
 51	if regex, exists := rc.cache[pattern]; exists {
 52		return regex, nil
 53	}
 54
 55	// Compile and cache the regex
 56	regex, err := regexp.Compile(pattern)
 57	if err != nil {
 58		return nil, err
 59	}
 60
 61	rc.cache[pattern] = regex
 62	return regex, nil
 63}
 64
// Global regex cache instances, created once at package initialisation.
var (
	// searchRegexCache caches content-search patterns; globRegexCache caches
	// the regexes derived from include globs.
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// Pre-compiled regex for glob conversion (used frequently). It matches a
	// non-empty brace group such as "{ts,tsx}".
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
 72
// GrepParams is the JSON input accepted by the grep tool.
type GrepParams struct {
	// Pattern is the regex (or literal text, see LiteralText) to search for.
	// It is required; Run rejects an empty pattern.
	Pattern string `json:"pattern"`
	// Path is the directory to search. When empty, the tool's working
	// directory is used.
	Path string `json:"path"`
	// Include is an optional file glob (e.g. "*.js", "*.{ts,tsx}") limiting
	// which files are searched.
	Include string `json:"include"`
	// LiteralText makes Run escape regex metacharacters in Pattern so it is
	// matched as plain text.
	LiteralText bool `json:"literal_text"`
}
 79
// grepMatch is a single search hit: one matching line in one file.
type grepMatch struct {
	// path is the file that contains the match.
	path string
	// modTime is the file's modification time; results are sorted on it.
	modTime time.Time
	// lineNum is the line number of the match. Run prints the line only when
	// it is greater than zero.
	lineNum int
	// lineText is the content of the matching line.
	lineText string
}
 86
// GrepResponseMetadata is attached to every successful grep response.
type GrepResponseMetadata struct {
	// NumberOfMatches is the number of matches included in the response,
	// i.e. counted after any truncation.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether matches were dropped to respect the limit.
	Truncated bool `json:"truncated"`
}
 91
// grepTool implements the "grep" tool.
type grepTool struct {
	// workingDir is the directory searched when a call supplies no path.
	workingDir string
}
 95
// GrepToolName is the name the grep tool reports through Name and Info.
const GrepToolName = "grep"
 97
// grepDescription holds the contents of grep.md, embedded at build time. Info
// returns it as the tool description.
//
//go:embed grep.md
var grepDescription []byte
100
101func NewGrepTool(workingDir string) BaseTool {
102	return &grepTool{
103		workingDir: workingDir,
104	}
105}
106
// Name returns the tool's name, GrepToolName.
func (g *grepTool) Name() string {
	return GrepToolName
}
110
111func (g *grepTool) Info() ToolInfo {
112	return ToolInfo{
113		Name:        GrepToolName,
114		Description: string(grepDescription),
115		Parameters: map[string]any{
116			"pattern": map[string]any{
117				"type":        "string",
118				"description": "The regex pattern to search for in file contents",
119			},
120			"path": map[string]any{
121				"type":        "string",
122				"description": "The directory to search in. Defaults to the current working directory.",
123			},
124			"include": map[string]any{
125				"type":        "string",
126				"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
127			},
128			"literal_text": map[string]any{
129				"type":        "boolean",
130				"description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
131			},
132		},
133		Required: []string{"pattern"},
134	}
135}
136
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) | [ ] { } ^ $) backslash-escaped, so the result
// matches the original text literally.
//
// It delegates to regexp.QuoteMeta, which escapes exactly this character set
// in a single pass.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
148
149func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
150	var params GrepParams
151	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
152		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
153	}
154
155	if params.Pattern == "" {
156		return NewTextErrorResponse("pattern is required"), nil
157	}
158
159	// If literal_text is true, escape the pattern
160	searchPattern := params.Pattern
161	if params.LiteralText {
162		searchPattern = escapeRegexPattern(params.Pattern)
163	}
164
165	searchPath := params.Path
166	if searchPath == "" {
167		searchPath = g.workingDir
168	}
169
170	matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
171	if err != nil {
172		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
173	}
174
175	var output strings.Builder
176	if len(matches) == 0 {
177		output.WriteString("No files found")
178	} else {
179		fmt.Fprintf(&output, "Found %d matches\n", len(matches))
180
181		currentFile := ""
182		for _, match := range matches {
183			if currentFile != match.path {
184				if currentFile != "" {
185					output.WriteString("\n")
186				}
187				currentFile = match.path
188				fmt.Fprintf(&output, "%s:\n", match.path)
189			}
190			if match.lineNum > 0 {
191				fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, match.lineText)
192			} else {
193				fmt.Fprintf(&output, "  %s\n", match.path)
194			}
195		}
196
197		if truncated {
198			output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
199		}
200	}
201
202	return WithResponseMetadata(
203		NewTextResponse(output.String()),
204		GrepResponseMetadata{
205			NumberOfMatches: len(matches),
206			Truncated:       truncated,
207		},
208	), nil
209}
210
211func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
212	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
213	if err != nil {
214		matches, err = searchFilesWithRegex(pattern, rootPath, include)
215		if err != nil {
216			return nil, false, err
217		}
218	}
219
220	sort.Slice(matches, func(i, j int) bool {
221		return matches[i].modTime.After(matches[j].modTime)
222	})
223
224	truncated := len(matches) > limit
225	if truncated {
226		matches = matches[:limit]
227	}
228
229	return matches, truncated, nil
230}
231
232func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
233	cmd := getRgSearchCmd(ctx, pattern, path, include)
234	if cmd == nil {
235		return nil, fmt.Errorf("ripgrep not found in $PATH")
236	}
237
238	// Only add ignore files if they exist
239	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
240		ignorePath := filepath.Join(path, ignoreFile)
241		if _, err := os.Stat(ignorePath); err == nil {
242			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
243		}
244	}
245
246	output, err := cmd.Output()
247	if err != nil {
248		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
249			return []grepMatch{}, nil
250		}
251		return nil, err
252	}
253
254	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
255	matches := make([]grepMatch, 0, len(lines))
256
257	for _, line := range lines {
258		if line == "" {
259			continue
260		}
261
262		// Parse ripgrep output format: file:line:content
263		parts := strings.SplitN(line, ":", 3)
264		if len(parts) < 3 {
265			continue
266		}
267
268		filePath := parts[0]
269		lineNum, err := strconv.Atoi(parts[1])
270		if err != nil {
271			continue
272		}
273		lineText := parts[2]
274
275		fileInfo, err := os.Stat(filePath)
276		if err != nil {
277			continue // Skip files we can't access
278		}
279
280		matches = append(matches, grepMatch{
281			path:     filePath,
282			modTime:  fileInfo.ModTime(),
283			lineNum:  lineNum,
284			lineText: lineText,
285		})
286	}
287
288	return matches, nil
289}
290
291func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
292	matches := []grepMatch{}
293
294	// Use cached regex compilation
295	regex, err := searchRegexCache.get(pattern)
296	if err != nil {
297		return nil, fmt.Errorf("invalid regex pattern: %w", err)
298	}
299
300	var includePattern *regexp.Regexp
301	if include != "" {
302		regexPattern := globToRegex(include)
303		includePattern, err = globRegexCache.get(regexPattern)
304		if err != nil {
305			return nil, fmt.Errorf("invalid include pattern: %w", err)
306		}
307	}
308
309	// Create walker with gitignore and crushignore support
310	walker := fsext.NewFastGlobWalker(rootPath)
311
312	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
313		if err != nil {
314			return nil // Skip errors
315		}
316
317		if info.IsDir() {
318			// Check if directory should be skipped
319			if walker.ShouldSkip(path) {
320				return filepath.SkipDir
321			}
322			return nil // Continue into directory
323		}
324
325		// Use walker's shouldSkip method for files
326		if walker.ShouldSkip(path) {
327			return nil
328		}
329
330		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
331		base := filepath.Base(path)
332		if base != "." && strings.HasPrefix(base, ".") {
333			return nil
334		}
335
336		if includePattern != nil && !includePattern.MatchString(path) {
337			return nil
338		}
339
340		match, lineNum, lineText, err := fileContainsPattern(path, regex)
341		if err != nil {
342			return nil // Skip files we can't read
343		}
344
345		if match {
346			matches = append(matches, grepMatch{
347				path:     path,
348				modTime:  info.ModTime(),
349				lineNum:  lineNum,
350				lineText: lineText,
351			})
352
353			if len(matches) >= 200 {
354				return filepath.SkipAll
355			}
356		}
357
358		return nil
359	})
360	if err != nil {
361		return nil, err
362	}
363
364	return matches, nil
365}
366
367func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
368	// Quick binary file detection
369	if isBinaryFile(filePath) {
370		return false, 0, "", nil
371	}
372
373	file, err := os.Open(filePath)
374	if err != nil {
375		return false, 0, "", err
376	}
377	defer file.Close()
378
379	scanner := bufio.NewScanner(file)
380	lineNum := 0
381	for scanner.Scan() {
382		lineNum++
383		line := scanner.Text()
384		if pattern.MatchString(line) {
385			return true, lineNum, line, nil
386		}
387	}
388
389	return false, 0, "", scanner.Err()
390}
391
// binaryExts is the set of lower-case file extensions that isBinaryFile
// treats as binary without looking at the file's content.
var binaryExts = map[string]struct{}{
	// executables, libraries and object code
	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
	// archives
	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
	// images
	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
	// documents
	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
	// audio and video
	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
}

// isBinaryFile makes a cheap guess at whether filePath is a binary file.
//
// A file whose extension (compared case-insensitively) is listed in binaryExts
// is binary regardless of content. Otherwise the first 512 bytes are read and
// the file is binary if they contain a NUL byte. A file that cannot be opened
// or read is reported as not binary, leaving the caller to hit the real error.
func isBinaryFile(filePath string) bool {
	if _, known := binaryExts[strings.ToLower(filepath.Ext(filePath))]; known {
		return true
	}

	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	// Sniff the head of the file for NUL bytes.
	var head [512]byte
	n, err := f.Read(head[:])
	if err != nil && err != io.EOF {
		return false
	}

	for _, b := range head[:n] {
		if b == 0 {
			return true
		}
	}
	return false
}
432
433func globToRegex(glob string) string {
434	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
435	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
436	regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
437
438	// Use pre-compiled regex instead of compiling each time
439	regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
440		inner := match[1 : len(match)-1]
441		return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
442	})
443
444	return regexPattern
445}