grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"bytes"
  6	"cmp"
  7	"context"
  8	_ "embed"
  9	"encoding/json"
 10	"fmt"
 11	"html/template"
 12	"io"
 13	"net/http"
 14	"os"
 15	"os/exec"
 16	"path/filepath"
 17	"regexp"
 18	"sort"
 19	"strings"
 20	"time"
 21
 22	"charm.land/fantasy"
 23	"github.com/charmbracelet/crush/internal/config"
 24	"github.com/charmbracelet/crush/internal/csync"
 25	"github.com/charmbracelet/crush/internal/fsext"
 26)
 27
// regexCache caches compiled regular expressions keyed by their pattern
// source. It embeds *csync.Map, so the map's methods (Get, Set, Reset) are
// promoted onto the cache; concurrency safety is presumably provided by
// csync.Map itself (its implementation is not visible from this file).
type regexCache struct {
	*csync.Map[string, *regexp.Regexp]
}
 32
 33// newRegexCache creates a new regex cache
 34func newRegexCache() *regexCache {
 35	return &regexCache{
 36		Map: csync.NewMap[string, *regexp.Regexp](),
 37	}
 38}
 39
 40// get retrieves a compiled regex from cache or compiles and caches it
 41func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 42	re, ok := rc.Get(pattern)
 43	if ok && re != nil {
 44		return re, nil
 45	}
 46	re, err := regexp.Compile(pattern)
 47	if err != nil {
 48		return nil, err
 49	}
 50	rc.Set(pattern, re)
 51	return re, nil
 52}
 53
 54// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
 55func ResetCache() {
 56	searchRegexCache.Reset(map[string]*regexp.Regexp{})
 57	globRegexCache.Reset(map[string]*regexp.Regexp{})
 58}
 59
// Package-level regex state shared by every grep invocation.
var (
	// searchRegexCache caches compiled content-search patterns;
	// globRegexCache caches include filters after globToRegex has turned
	// them into regular expressions. Both are cleared by ResetCache.
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// globBraceRegex matches one brace group such as {ts,tsx}. It is
	// compiled once here because globToRegex uses it on every call.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
 67
// GrepParams is the input accepted by the grep tool.
//
// Pattern is required (an empty pattern is rejected by the tool handler).
// Path falls back to the tool's working directory when empty, Include is an
// optional file glob, and LiteralText requests that Pattern be escaped so it
// is matched verbatim instead of being interpreted as a regex.
type GrepParams struct {
	Pattern     string `json:"pattern" description:"The regex pattern to search for in file contents"`
	Path        string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	Include     string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	LiteralText bool   `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
 74
// grepMatch is a single search hit as produced by either search backend
// (ripgrep or the regexp-based fallback).
type grepMatch struct {
	path     string    // file containing the hit
	modTime  time.Time // modification time of that file; results are ordered newest first
	lineNum  int       // 1-based line number; values <= 0 are rendered without line details
	charNum  int       // 1-based byte offset of the hit within the line; <= 0 omits it from output
	lineText string    // text of the matching line
}
 82
// GrepResponseMetadata is attached to every successful grep response.
//
// NumberOfMatches is the number of matches included in the response (that
// is, after truncation). Truncated reports whether more matches were found
// than the tool returns.
type GrepResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 87
// GrepToolName is the name under which the grep tool is registered.
// maxGrepContentWidth is the maximum number of bytes of a matching line that
// is shown in the output; longer lines are cut and suffixed with "...".
const (
	GrepToolName        = "grep"
	maxGrepContentWidth = 500
)
 92
// grepDescriptionTmpl holds the raw contents of grep.md.tpl, embedded into
// the binary at build time.
//
//go:embed grep.md.tpl
var grepDescriptionTmpl []byte

// grepDescriptionTpl is the parsed tool-description template. template.Must
// panics during package initialization if the embedded template fails to
// parse.
var grepDescriptionTpl = template.Must(
	template.New("grepDescription").
		Parse(string(grepDescriptionTmpl)),
)
100
// grepDescriptionData is the value passed to grepDescriptionTpl when the
// tool description is rendered.
type grepDescriptionData struct {
	// MaxResults is the result limit advertised in the description.
	MaxResults int
}
104
105func grepDescription() string {
106	return renderTemplate(grepDescriptionTpl, grepDescriptionData{
107		MaxResults: 100,
108	})
109}
110
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter backslash-escaped, so that the result matches the input text
// literally.
//
// It delegates to regexp.QuoteMeta, which escapes exactly the characters
// \ . + * ? ( ) | [ ] { } ^ $ in a single pass.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
122
123func NewGrepTool(workingDir string, config config.ToolGrep) fantasy.AgentTool {
124	return fantasy.NewAgentTool(
125		GrepToolName,
126		grepDescription(),
127		func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
128			if params.Pattern == "" {
129				return fantasy.NewTextErrorResponse("pattern is required"), nil
130			}
131
132			searchPattern := params.Pattern
133			if params.LiteralText {
134				searchPattern = escapeRegexPattern(params.Pattern)
135			}
136
137			searchPath := cmp.Or(params.Path, workingDir)
138
139			searchCtx, cancel := context.WithTimeout(ctx, config.GetTimeout())
140			defer cancel()
141
142			matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
143			if err != nil {
144				return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
145			}
146
147			var output strings.Builder
148			if len(matches) == 0 {
149				output.WriteString("No files found")
150			} else {
151				fmt.Fprintf(&output, "Found %d matches\n", len(matches))
152
153				currentFile := ""
154				for _, match := range matches {
155					if currentFile != match.path {
156						if currentFile != "" {
157							output.WriteString("\n")
158						}
159						currentFile = match.path
160						fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
161					}
162					if match.lineNum > 0 {
163						lineText := match.lineText
164						if len(lineText) > maxGrepContentWidth {
165							lineText = lineText[:maxGrepContentWidth] + "..."
166						}
167						if match.charNum > 0 {
168							fmt.Fprintf(&output, "  Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
169						} else {
170							fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, lineText)
171						}
172					} else {
173						fmt.Fprintf(&output, "  %s\n", match.path)
174					}
175				}
176
177				if truncated {
178					output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
179				}
180			}
181
182			return fantasy.WithResponseMetadata(
183				fantasy.NewTextResponse(output.String()),
184				GrepResponseMetadata{
185					NumberOfMatches: len(matches),
186					Truncated:       truncated,
187				},
188			), nil
189		},
190	)
191}
192
193func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
194	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
195	if err != nil {
196		matches, err = searchFilesWithRegex(pattern, rootPath, include)
197		if err != nil {
198			return nil, false, err
199		}
200	}
201
202	sort.Slice(matches, func(i, j int) bool {
203		return matches[i].modTime.After(matches[j].modTime)
204	})
205
206	truncated := len(matches) > limit
207	if truncated {
208		matches = matches[:limit]
209	}
210
211	return matches, truncated, nil
212}
213
214func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
215	cmd := getRgSearchCmd(ctx, pattern, path, include)
216	if cmd == nil {
217		return nil, fmt.Errorf("ripgrep not found in $PATH")
218	}
219
220	// Only add ignore files if they exist
221	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
222		ignorePath := filepath.Join(path, ignoreFile)
223		if _, err := os.Stat(ignorePath); err == nil {
224			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
225		}
226	}
227
228	output, err := cmd.Output()
229	if err != nil {
230		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
231			return []grepMatch{}, nil
232		}
233		return nil, err
234	}
235
236	var matches []grepMatch
237	for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
238		if len(line) == 0 {
239			continue
240		}
241		var match ripgrepMatch
242		if err := json.Unmarshal(line, &match); err != nil {
243			continue
244		}
245		if match.Type != "match" {
246			continue
247		}
248		for _, m := range match.Data.Submatches {
249			fi, err := os.Stat(match.Data.Path.Text)
250			if err != nil {
251				continue // Skip files we can't access
252			}
253			matches = append(matches, grepMatch{
254				path:     match.Data.Path.Text,
255				modTime:  fi.ModTime(),
256				lineNum:  match.Data.LineNumber,
257				charNum:  m.Start + 1, // ensure 1-based
258				lineText: strings.TrimSpace(match.Data.Lines.Text),
259			})
260			// only get the first match of each line
261			break
262		}
263	}
264	return matches, nil
265}
266
// ripgrepMatch mirrors the subset of one ripgrep JSON output line that
// searchWithRipgrep reads. Fields not listed here are ignored by the decoder.
type ripgrepMatch struct {
	// Type is the record kind; only "match" records are used.
	Type string `json:"type"`
	Data struct {
		// Path is the file the record refers to.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines is the text of the matching line(s).
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		// LineNumber is the line number reported by ripgrep.
		LineNumber int `json:"line_number"`
		// Submatches lists the individual hits on the line; Start is the
		// offset of a hit within the line (searchWithRipgrep adds 1 to make
		// it 1-based).
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
282
283func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
284	matches := []grepMatch{}
285
286	// Use cached regex compilation
287	regex, err := searchRegexCache.get(pattern)
288	if err != nil {
289		return nil, fmt.Errorf("invalid regex pattern: %w", err)
290	}
291
292	var includePattern *regexp.Regexp
293	if include != "" {
294		regexPattern := globToRegex(include)
295		includePattern, err = globRegexCache.get(regexPattern)
296		if err != nil {
297			return nil, fmt.Errorf("invalid include pattern: %w", err)
298		}
299	}
300
301	// Create walker with gitignore and crushignore support
302	walker := fsext.NewFastGlobWalker(rootPath)
303
304	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
305		if err != nil {
306			return nil // Skip errors
307		}
308
309		if info.IsDir() {
310			// Check if directory should be skipped
311			if walker.ShouldSkip(path) {
312				return filepath.SkipDir
313			}
314			return nil // Continue into directory
315		}
316
317		// Use walker's shouldSkip method for files
318		if walker.ShouldSkip(path) {
319			return nil
320		}
321
322		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
323		base := filepath.Base(path)
324		if base != "." && strings.HasPrefix(base, ".") {
325			return nil
326		}
327
328		if includePattern != nil && !includePattern.MatchString(path) {
329			return nil
330		}
331
332		match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
333		if err != nil {
334			return nil // Skip files we can't read
335		}
336
337		if match {
338			matches = append(matches, grepMatch{
339				path:     path,
340				modTime:  info.ModTime(),
341				lineNum:  lineNum,
342				charNum:  charNum,
343				lineText: lineText,
344			})
345
346			if len(matches) >= 200 {
347				return filepath.SkipAll
348			}
349		}
350
351		return nil
352	})
353	if err != nil {
354		return nil, err
355	}
356
357	return matches, nil
358}
359
360func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
361	if pattern == nil {
362		return false, 0, 0, "", nil
363	}
364	// Only search text files.
365	if !isTextFile(filePath) {
366		return false, 0, 0, "", nil
367	}
368
369	file, err := os.Open(filePath)
370	if err != nil {
371		return false, 0, 0, "", err
372	}
373	defer file.Close()
374
375	reader := bufio.NewReader(file)
376	lineNum := 0
377	for {
378		line, err := reader.ReadString('\n')
379		lineNum++
380		line = strings.TrimSuffix(line, "\n")
381		line = strings.TrimSuffix(line, "\r")
382		if loc := pattern.FindStringIndex(line); loc != nil {
383			charNum := loc[0] + 1
384			return true, lineNum, charNum, line, nil
385		}
386		if err == io.EOF {
387			break
388		}
389		if err != nil {
390			return false, 0, 0, "", err
391		}
392	}
393
394	return false, 0, 0, "", nil
395}
396
// isTextFile reports whether the file at filePath looks like text, judging
// by the content type sniffed from (at most) its first 512 bytes. Files that
// cannot be opened or read are reported as non-text.
func isTextFile(filePath string) bool {
	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	// 512 bytes is the sample size used for content-type detection.
	var head [512]byte
	n, err := f.Read(head[:])
	if err != nil && err != io.EOF {
		return false
	}

	switch mime := http.DetectContentType(head[:n]); mime {
	case "application/json",
		"application/xml",
		"application/javascript",
		"application/x-sh":
		return true
	default:
		return strings.HasPrefix(mime, "text/")
	}
}
422
423func globToRegex(glob string) string {
424	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
425	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
426	regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
427
428	// Use pre-compiled regex instead of compiling each time
429	regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
430		inner := match[1 : len(match)-1]
431		return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
432	})
433
434	return regexPattern
435}