grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"cmp"
  7	"context"
  8	_ "embed"
  9	"encoding/json"
 10	"fmt"
 11	"html/template"
 12	"io"
 13	"net/http"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"regexp"
 18	"sort"
 19	"strings"
 20	"time"
 21
 22	"charm.land/fantasy"
 23	"github.com/charmbracelet/crush/internal/config"
 24	"github.com/charmbracelet/crush/internal/csync"
 25	"github.com/charmbracelet/crush/internal/fsext"
 26)
 27
// regexCache caches compiled regular expressions keyed by their pattern
// source string. It embeds a csync.Map, whose Get/Set/Reset methods are
// promoted onto regexCache; the original note describes it as thread-safe
// (NOTE(review): that guarantee comes from csync.Map — confirm there).
type regexCache struct {
	*csync.Map[string, *regexp.Regexp]
}
 32
 33// newRegexCache creates a new regex cache
 34func newRegexCache() *regexCache {
 35	return &regexCache{
 36		Map: csync.NewMap[string, *regexp.Regexp](),
 37	}
 38}
 39
 40// get retrieves a compiled regex from cache or compiles and caches it
 41func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 42	re, ok := rc.Get(pattern)
 43	if ok && re != nil {
 44		return re, nil
 45	}
 46	re, err := regexp.Compile(pattern)
 47	if err != nil {
 48		return nil, err
 49	}
 50	rc.Set(pattern, re)
 51	return re, nil
 52}
 53
 54// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
 55func ResetCache() {
 56	searchRegexCache.Reset(map[string]*regexp.Regexp{})
 57	globRegexCache.Reset(map[string]*regexp.Regexp{})
 58}
 59
// Global regex cache instances, shared by every grep invocation.
//
// searchRegexCache holds compiled search patterns (searchFilesWithRegex looks
// patterns up here); globRegexCache holds compiled include patterns after
// glob-to-regex translation. ResetCache empties both.
var (
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// Pre-compiled regex for glob conversion (used frequently): it matches a
	// "{a,b,...}" brace group so globToRegex can rewrite it as "(a|b|...)".
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
 67
// GrepParams is the argument object decoded from a grep tool call.
//
// Pattern is required (the tool rejects an empty value). Path falls back to
// the tool's working directory when empty. Include is an optional file glob
// passed through to the search. LiteralText makes the tool escape Pattern
// before searching. The description tags presumably become the parameter
// schema shown to the model (NOTE(review): confirm in the fantasy package).
type GrepParams struct {
	Pattern     string `json:"pattern" description:"The regex pattern to search for in file contents"`
	Path        string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	Include     string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	LiteralText bool   `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
 74
// grepMatch is one reported search hit.
//
// path is the file and modTime its modification time, which is used to order
// results newest-first. lineNum and charNum are 1-based positions of the
// match; the output code treats lineNum == 0 as "no line information".
// lineText is the matched line (whitespace-trimmed for ripgrep results).
type grepMatch struct {
	path     string
	modTime  time.Time
	lineNum  int
	charNum  int
	lineText string
}
 82
// GrepResponseMetadata is attached to every successful grep tool response.
// NumberOfMatches is the number of matches returned (after truncation), and
// Truncated reports whether more matches than the result limit were found.
type GrepResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 87
// GrepToolName is the name the grep tool is registered under.
// maxGrepContentWidth caps, in bytes, how much of a matched line is printed
// in the tool output; longer lines are cut and suffixed with "...".
const (
	GrepToolName        = "grep"
	maxGrepContentWidth = 500
)
 92
// grepDescriptionTmpl is the raw text of grep.md.tpl, embedded at build time.
//
//go:embed grep.md.tpl
var grepDescriptionTmpl []byte

// grepDescriptionTpl is grep.md.tpl parsed once during package
// initialisation; template.Must panics at init if parsing fails.
var grepDescriptionTpl = template.Must(
	template.New("grepDescription").
		Parse(string(grepDescriptionTmpl)),
)
100
// grepDescriptionData holds the values substituted into grep.md.tpl.
// grepDescription sets MaxResults to the tool's result limit (100).
type grepDescriptionData struct {
	MaxResults int
}
104
105func grepDescription() string {
106	return renderTemplate(grepDescriptionTpl, grepDescriptionData{
107		MaxResults: 100,
108	})
109}
110
// escapeRegexPattern escapes every regular-expression metacharacter in
// pattern so that the result matches pattern literally.
//
// It delegates to regexp.QuoteMeta. QuoteMeta escapes exactly the set
// `\.+*?()|[]{}^$` that the previous hand-rolled loop handled, but does it
// in a single pass instead of one strings.ReplaceAll per character.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
122
123func NewGrepTool(workingDir string, config config.ToolGrep) fantasy.AgentTool {
124	return fantasy.NewAgentTool(
125		GrepToolName,
126		grepDescription(),
127		func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
128			if params.Pattern == "" {
129				return fantasy.NewTextErrorResponse("pattern is required"), nil
130			}
131
132			searchPattern := params.Pattern
133			if params.LiteralText {
134				searchPattern = escapeRegexPattern(params.Pattern)
135			}
136
137			searchPath := cmp.Or(params.Path, workingDir)
138
139			searchCtx, cancel := context.WithTimeout(ctx, config.GetTimeout())
140			defer cancel()
141
142			matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
143			if err != nil {
144				return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
145			}
146
147			var output strings.Builder
148			if len(matches) == 0 {
149				output.WriteString("No files found")
150			} else {
151				fmt.Fprintf(&output, "Found %d matches\n", len(matches))
152
153				currentFile := ""
154				for _, match := range matches {
155					if currentFile != match.path {
156						if currentFile != "" {
157							output.WriteString("\n")
158						}
159						currentFile = match.path
160						fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
161					}
162					if match.lineNum > 0 {
163						lineText := match.lineText
164						if len(lineText) > maxGrepContentWidth {
165							lineText = lineText[:maxGrepContentWidth] + "..."
166						}
167						if match.charNum > 0 {
168							fmt.Fprintf(&output, "  Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
169						} else {
170							fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, lineText)
171						}
172					} else {
173						fmt.Fprintf(&output, "  %s\n", match.path)
174					}
175				}
176
177				if truncated {
178					output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
179				}
180			}
181
182			return fantasy.WithResponseMetadata(
183				fantasy.NewTextResponse(output.String()),
184				GrepResponseMetadata{
185					NumberOfMatches: len(matches),
186					Truncated:       truncated,
187				},
188			), nil
189		})
190}
191
192func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
193	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
194	if err != nil {
195		matches, err = searchFilesWithRegex(pattern, rootPath, include)
196		if err != nil {
197			return nil, false, err
198		}
199	}
200
201	sort.Slice(matches, func(i, j int) bool {
202		return matches[i].modTime.After(matches[j].modTime)
203	})
204
205	truncated := len(matches) > limit
206	if truncated {
207		matches = matches[:limit]
208	}
209
210	return matches, truncated, nil
211}
212
213func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
214	cmd := getRgSearchCmd(ctx, pattern, path, include)
215	if cmd == nil {
216		return nil, fmt.Errorf("ripgrep not found in $PATH")
217	}
218
219	// Only add ignore files if they exist
220	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
221		ignorePath := filepath.Join(path, ignoreFile)
222		if _, err := os.Stat(ignorePath); err == nil {
223			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
224		}
225	}
226
227	output, err := cmd.Output()
228	if err != nil {
229		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
230			return []grepMatch{}, nil
231		}
232		return nil, err
233	}
234
235	var matches []grepMatch
236	for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
237		if len(line) == 0 {
238			continue
239		}
240		var match ripgrepMatch
241		if err := json.Unmarshal(line, &match); err != nil {
242			continue
243		}
244		if match.Type != "match" {
245			continue
246		}
247		for _, m := range match.Data.Submatches {
248			fi, err := os.Stat(match.Data.Path.Text)
249			if err != nil {
250				continue // Skip files we can't access
251			}
252			matches = append(matches, grepMatch{
253				path:     match.Data.Path.Text,
254				modTime:  fi.ModTime(),
255				lineNum:  match.Data.LineNumber,
256				charNum:  m.Start + 1, // ensure 1-based
257				lineText: strings.TrimSpace(match.Data.Lines.Text),
258			})
259			// only get the first match of each line
260			break
261		}
262	}
263	return matches, nil
264}
265
// ripgrepMatch is the subset of one JSON record of ripgrep output that
// searchWithRipgrep decodes; any other fields are ignored by encoding/json.
//
// Only records with Type == "match" are used. Data.Path.Text is the file,
// Data.Lines.Text the matched line, and Data.LineNumber its line number.
// Data.Submatches lists match spans within the line, whose Start is 0-based
// (searchWithRipgrep adds 1). It is presumably a byte offset — confirm
// against ripgrep's JSON output docs.
type ripgrepMatch struct {
	Type string `json:"type"`
	Data struct {
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		LineNumber int `json:"line_number"`
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
281
282func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
283	matches := []grepMatch{}
284
285	// Use cached regex compilation
286	regex, err := searchRegexCache.get(pattern)
287	if err != nil {
288		return nil, fmt.Errorf("invalid regex pattern: %w", err)
289	}
290
291	var includePattern *regexp.Regexp
292	if include != "" {
293		regexPattern := globToRegex(include)
294		includePattern, err = globRegexCache.get(regexPattern)
295		if err != nil {
296			return nil, fmt.Errorf("invalid include pattern: %w", err)
297		}
298	}
299
300	// Create walker with gitignore and crushignore support
301	walker := fsext.NewFastGlobWalker(rootPath)
302
303	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
304		if err != nil {
305			return nil // Skip errors
306		}
307
308		if info.IsDir() {
309			// Check if directory should be skipped
310			if walker.ShouldSkip(path) {
311				return filepath.SkipDir
312			}
313			return nil // Continue into directory
314		}
315
316		// Use walker's shouldSkip method for files
317		if walker.ShouldSkip(path) {
318			return nil
319		}
320
321		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
322		base := filepath.Base(path)
323		if base != "." && strings.HasPrefix(base, ".") {
324			return nil
325		}
326
327		if includePattern != nil && !includePattern.MatchString(path) {
328			return nil
329		}
330
331		match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
332		if err != nil {
333			return nil // Skip files we can't read
334		}
335
336		if match {
337			matches = append(matches, grepMatch{
338				path:     path,
339				modTime:  info.ModTime(),
340				lineNum:  lineNum,
341				charNum:  charNum,
342				lineText: lineText,
343			})
344
345			if len(matches) >= 200 {
346				return filepath.SkipAll
347			}
348		}
349
350		return nil
351	})
352	if err != nil {
353		return nil, err
354	}
355
356	return matches, nil
357}
358
359func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
360	if pattern == nil {
361		return false, 0, 0, "", nil
362	}
363	// Only search text files.
364	if !isTextFile(filePath) {
365		return false, 0, 0, "", nil
366	}
367
368	file, err := os.Open(filePath)
369	if err != nil {
370		return false, 0, 0, "", err
371	}
372	defer file.Close()
373
374	scanner := bufio.NewScanner(file)
375	lineNum := 0
376	for scanner.Scan() {
377		lineNum++
378		line := scanner.Text()
379		if loc := pattern.FindStringIndex(line); loc != nil {
380			charNum := loc[0] + 1
381			return true, lineNum, charNum, line, nil
382		}
383	}
384
385	return false, 0, 0, "", scanner.Err()
386}
387
// isTextFile sniffs up to the first 512 bytes of filePath with
// http.DetectContentType. It reports true for any "text/*" type and for
// JSON, XML, JavaScript and shell-script types. Unopenable or unreadable
// files report false.
func isTextFile(filePath string) bool {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		return false
	}
	defer f.Close()

	var head [512]byte
	n, readErr := f.Read(head[:])
	if readErr != nil && readErr != io.EOF {
		return false
	}

	mime := http.DetectContentType(head[:n])
	if strings.HasPrefix(mime, "text/") {
		return true
	}
	switch mime {
	case "application/json", "application/xml", "application/javascript", "application/x-sh":
		return true
	}
	return false
}
413
414func globToRegex(glob string) string {
415	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
416	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
417	regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
418
419	// Use pre-compiled regex instead of compiling each time
420	regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
421		inner := match[1 : len(match)-1]
422		return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
423	})
424
425	return regexPattern
426}