grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"cmp"
  7	"context"
  8	_ "embed"
  9	"encoding/json"
 10	"fmt"
 11	"io"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"path/filepath"
 16	"regexp"
 17	"sort"
 18	"strings"
 19	"time"
 20
 21	"charm.land/fantasy"
 22	"github.com/charmbracelet/crush/internal/config"
 23	"github.com/charmbracelet/crush/internal/csync"
 24	"github.com/charmbracelet/crush/internal/fsext"
 25)
 26
 27// regexCache provides thread-safe caching of compiled regex patterns
 28type regexCache struct {
 29	*csync.Map[string, *regexp.Regexp]
 30}
 31
 32// newRegexCache creates a new regex cache
 33func newRegexCache() *regexCache {
 34	return &regexCache{
 35		Map: csync.NewMap[string, *regexp.Regexp](),
 36	}
 37}
 38
 39// get retrieves a compiled regex from cache or compiles and caches it
 40func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 41	re, ok := rc.Get(pattern)
 42	if ok && re != nil {
 43		return re, nil
 44	}
 45	re, err := regexp.Compile(pattern)
 46	if err != nil {
 47		return nil, err
 48	}
 49	rc.Set(pattern, re)
 50	return re, nil
 51}
 52
 53// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
 54func ResetCache() {
 55	searchRegexCache.Reset(map[string]*regexp.Regexp{})
 56	globRegexCache.Reset(map[string]*regexp.Regexp{})
 57}
 58
// Package-level regex state shared by every grep invocation.
var (
	// searchRegexCache holds compiled search patterns used by the pure-Go
	// fallback search.
	searchRegexCache = newRegexCache()
	// globRegexCache holds compiled include patterns (after glob-to-regex
	// conversion).
	globRegexCache = newRegexCache()
	// globBraceRegex matches a non-empty `{...}` group; globToRegex uses it to
	// expand brace alternatives such as {ts,tsx}. Compiled once here rather
	// than on every conversion.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
 66
// GrepParams are the arguments accepted by the grep tool. The json tags name
// the fields in the tool call; the description tags are presumably surfaced
// to the model as the parameter schema (NOTE(review): confirm against
// fantasy.NewAgentTool).
type GrepParams struct {
	// Pattern is required; an empty pattern is rejected by the tool.
	Pattern string `json:"pattern" description:"The regex pattern to search for in file contents"`
	// Path falls back to the tool's working directory when empty.
	Path string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	// Include is a glob restricting which files are searched.
	Include string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	// LiteralText causes Pattern to be regex-escaped before searching.
	LiteralText bool `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
 73
// grepMatch is a single search hit produced by either search backend
// (ripgrep or the pure-Go regex walk).
type grepMatch struct {
	path     string    // file containing the match, as reported by the backend
	modTime  time.Time // file modification time; results are sorted newest first
	lineNum  int       // 1-based line number; 0 means no line information
	charNum  int       // 1-based byte column where the match starts; 0 means unknown
	lineText string    // text of the matching line
}
 81
// GrepResponseMetadata is attached to the grep tool's text response.
type GrepResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"` // matches included in the response, after truncation
	Truncated       bool `json:"truncated"`         // true when more matches existed than were returned
}
 86
const (
	// GrepToolName is the name under which the grep tool is registered.
	GrepToolName = "grep"
	// maxGrepContentWidth is the number of bytes of a matching line echoed
	// back before the line is cut and suffixed with "...".
	maxGrepContentWidth = 500
)
 91
// grepDescription is the tool's description text, embedded from grep.md at
// build time.
//
//go:embed grep.md
var grepDescription []byte
 94
 95// escapeRegexPattern escapes special regex characters so they're treated as literal characters
 96func escapeRegexPattern(pattern string) string {
 97	specialChars := []string{"\\", ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"}
 98	escaped := pattern
 99
100	for _, char := range specialChars {
101		escaped = strings.ReplaceAll(escaped, char, "\\"+char)
102	}
103
104	return escaped
105}
106
107func NewGrepTool(workingDir string, config config.ToolGrep) fantasy.AgentTool {
108	return fantasy.NewAgentTool(
109		GrepToolName,
110		FirstLineDescription(grepDescription),
111		func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
112			if params.Pattern == "" {
113				return fantasy.NewTextErrorResponse("pattern is required"), nil
114			}
115
116			searchPattern := params.Pattern
117			if params.LiteralText {
118				searchPattern = escapeRegexPattern(params.Pattern)
119			}
120
121			searchPath := cmp.Or(params.Path, workingDir)
122
123			searchCtx, cancel := context.WithTimeout(ctx, config.GetTimeout())
124			defer cancel()
125
126			matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
127			if err != nil {
128				return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
129			}
130
131			var output strings.Builder
132			if len(matches) == 0 {
133				output.WriteString("No files found")
134			} else {
135				fmt.Fprintf(&output, "Found %d matches\n", len(matches))
136
137				currentFile := ""
138				for _, match := range matches {
139					if currentFile != match.path {
140						if currentFile != "" {
141							output.WriteString("\n")
142						}
143						currentFile = match.path
144						fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
145					}
146					if match.lineNum > 0 {
147						lineText := match.lineText
148						if len(lineText) > maxGrepContentWidth {
149							lineText = lineText[:maxGrepContentWidth] + "..."
150						}
151						if match.charNum > 0 {
152							fmt.Fprintf(&output, "  Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
153						} else {
154							fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, lineText)
155						}
156					} else {
157						fmt.Fprintf(&output, "  %s\n", match.path)
158					}
159				}
160
161				if truncated {
162					output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
163				}
164			}
165
166			return fantasy.WithResponseMetadata(
167				fantasy.NewTextResponse(output.String()),
168				GrepResponseMetadata{
169					NumberOfMatches: len(matches),
170					Truncated:       truncated,
171				},
172			), nil
173		})
174}
175
176func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
177	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
178	if err != nil {
179		matches, err = searchFilesWithRegex(pattern, rootPath, include)
180		if err != nil {
181			return nil, false, err
182		}
183	}
184
185	sort.Slice(matches, func(i, j int) bool {
186		return matches[i].modTime.After(matches[j].modTime)
187	})
188
189	truncated := len(matches) > limit
190	if truncated {
191		matches = matches[:limit]
192	}
193
194	return matches, truncated, nil
195}
196
197func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
198	cmd := getRgSearchCmd(ctx, pattern, path, include)
199	if cmd == nil {
200		return nil, fmt.Errorf("ripgrep not found in $PATH")
201	}
202
203	// Only add ignore files if they exist
204	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
205		ignorePath := filepath.Join(path, ignoreFile)
206		if _, err := os.Stat(ignorePath); err == nil {
207			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
208		}
209	}
210
211	output, err := cmd.Output()
212	if err != nil {
213		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
214			return []grepMatch{}, nil
215		}
216		return nil, err
217	}
218
219	var matches []grepMatch
220	for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
221		if len(line) == 0 {
222			continue
223		}
224		var match ripgrepMatch
225		if err := json.Unmarshal(line, &match); err != nil {
226			continue
227		}
228		if match.Type != "match" {
229			continue
230		}
231		for _, m := range match.Data.Submatches {
232			fi, err := os.Stat(match.Data.Path.Text)
233			if err != nil {
234				continue // Skip files we can't access
235			}
236			matches = append(matches, grepMatch{
237				path:     match.Data.Path.Text,
238				modTime:  fi.ModTime(),
239				lineNum:  match.Data.LineNumber,
240				charNum:  m.Start + 1, // ensure 1-based
241				lineText: strings.TrimSpace(match.Data.Lines.Text),
242			})
243			// only get the first match of each line
244			break
245		}
246	}
247	return matches, nil
248}
249
// ripgrepMatch decodes one line of ripgrep's JSON output. It assumes
// getRgSearchCmd runs ripgrep in JSON mode — NOTE(review): confirm. Only the
// fields read by searchWithRipgrep are declared. Lines whose Type is not
// "match" still decode but are skipped by the caller.
type ripgrepMatch struct {
	Type string `json:"type"` // event kind; only "match" events are used
	Data struct {
		Path struct {
			Text string `json:"text"` // path of the file containing the match
		} `json:"path"`
		Lines struct {
			Text string `json:"text"` // text of the matching line (trimmed by the caller)
		} `json:"lines"`
		LineNumber int `json:"line_number"` // reported line number, used as-is
		Submatches []struct {
			Start int `json:"start"` // offset of the submatch within the line; the caller adds 1 for a 1-based column
		} `json:"submatches"`
	} `json:"data"`
}
265
266func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
267	matches := []grepMatch{}
268
269	// Use cached regex compilation
270	regex, err := searchRegexCache.get(pattern)
271	if err != nil {
272		return nil, fmt.Errorf("invalid regex pattern: %w", err)
273	}
274
275	var includePattern *regexp.Regexp
276	if include != "" {
277		regexPattern := globToRegex(include)
278		includePattern, err = globRegexCache.get(regexPattern)
279		if err != nil {
280			return nil, fmt.Errorf("invalid include pattern: %w", err)
281		}
282	}
283
284	// Create walker with gitignore and crushignore support
285	walker := fsext.NewFastGlobWalker(rootPath)
286
287	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
288		if err != nil {
289			return nil // Skip errors
290		}
291
292		if info.IsDir() {
293			// Check if directory should be skipped
294			if walker.ShouldSkip(path) {
295				return filepath.SkipDir
296			}
297			return nil // Continue into directory
298		}
299
300		// Use walker's shouldSkip method for files
301		if walker.ShouldSkip(path) {
302			return nil
303		}
304
305		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
306		base := filepath.Base(path)
307		if base != "." && strings.HasPrefix(base, ".") {
308			return nil
309		}
310
311		if includePattern != nil && !includePattern.MatchString(path) {
312			return nil
313		}
314
315		match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
316		if err != nil {
317			return nil // Skip files we can't read
318		}
319
320		if match {
321			matches = append(matches, grepMatch{
322				path:     path,
323				modTime:  info.ModTime(),
324				lineNum:  lineNum,
325				charNum:  charNum,
326				lineText: lineText,
327			})
328
329			if len(matches) >= 200 {
330				return filepath.SkipAll
331			}
332		}
333
334		return nil
335	})
336	if err != nil {
337		return nil, err
338	}
339
340	return matches, nil
341}
342
343func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
344	if pattern == nil {
345		return false, 0, 0, "", nil
346	}
347	// Only search text files.
348	if !isTextFile(filePath) {
349		return false, 0, 0, "", nil
350	}
351
352	file, err := os.Open(filePath)
353	if err != nil {
354		return false, 0, 0, "", err
355	}
356	defer file.Close()
357
358	scanner := bufio.NewScanner(file)
359	lineNum := 0
360	for scanner.Scan() {
361		lineNum++
362		line := scanner.Text()
363		if loc := pattern.FindStringIndex(line); loc != nil {
364			charNum := loc[0] + 1
365			return true, lineNum, charNum, line, nil
366		}
367	}
368
369	return false, 0, 0, "", scanner.Err()
370}
371
// isTextFile reports whether filePath looks like a text file, judged by
// http.DetectContentType on (up to) its first 512 bytes. Files that cannot
// be opened or read, including directories, report false. An empty file
// reports true, because DetectContentType classifies no data as text/plain.
func isTextFile(filePath string) bool {
	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	// DetectContentType considers at most 512 bytes. io.ReadFull keeps
	// reading until the buffer is full or the file ends, whereas a single
	// Read may legitimately return fewer bytes than are available. A short
	// file surfaces as io.EOF (nothing read) or io.ErrUnexpectedEOF (partial
	// read); both are fine.
	buffer := make([]byte, 512)
	n, err := io.ReadFull(file, buffer)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return false
	}

	contentType := http.DetectContentType(buffer[:n])

	// Check if it's a text MIME type.
	return strings.HasPrefix(contentType, "text/") ||
		contentType == "application/json" ||
		contentType == "application/xml" ||
		contentType == "application/javascript" ||
		contentType == "application/x-sh"
}
397
398func globToRegex(glob string) string {
399	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
400	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
401	regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
402
403	// Use pre-compiled regex instead of compiling each time
404	regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
405		inner := match[1 : len(match)-1]
406		return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
407	})
408
409	return regexPattern
410}