grep.go

  1package tools
  2
import (
	"bufio"
	"bytes"
	"cmp"
	"context"
	_ "embed"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"
	"unicode/utf8"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/fsext"
)
 26
 27// regexCache provides thread-safe caching of compiled regex patterns
 28type regexCache struct {
 29	*csync.Map[string, *regexp.Regexp]
 30}
 31
 32// newRegexCache creates a new regex cache
 33func newRegexCache() *regexCache {
 34	return &regexCache{
 35		Map: csync.NewMap[string, *regexp.Regexp](),
 36	}
 37}
 38
 39// get retrieves a compiled regex from cache or compiles and caches it
 40func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 41	var rerr error
 42	return rc.GetOrSet(pattern, func() *regexp.Regexp {
 43		regex, err := regexp.Compile(pattern)
 44		if err != nil {
 45			rerr = err
 46		}
 47		return regex
 48	}), rerr
 49}
 50
 51// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
 52func ResetCache() {
 53	searchRegexCache.Reset(map[string]*regexp.Regexp{})
 54	globRegexCache.Reset(map[string]*regexp.Regexp{})
 55}
 56
 57// Global regex cache instances
 58var (
 59	searchRegexCache = newRegexCache()
 60	globRegexCache   = newRegexCache()
 61	// Pre-compiled regex for glob conversion (used frequently)
 62	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
 63)
 64
// GrepParams are the arguments accepted by the grep tool.
//
// NOTE(review): the description tags are presumably read by fantasy to build
// the tool's JSON schema, so their text and the field order are treated as
// runtime data — confirm before editing them.
type GrepParams struct {
	// Pattern is required; an empty value is rejected by the tool handler.
	Pattern string `json:"pattern" description:"The regex pattern to search for in file contents"`
	// Path falls back to the tool's working directory when empty.
	Path string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	// Include is a glob restricting which files are searched.
	Include string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	// LiteralText makes the handler escape Pattern before searching.
	LiteralText bool `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
 71
 72type grepMatch struct {
 73	path     string
 74	modTime  time.Time
 75	lineNum  int
 76	charNum  int
 77	lineText string
 78}
 79
// GrepResponseMetadata is attached to every grep tool response.
type GrepResponseMetadata struct {
	// NumberOfMatches is how many matches the response contains (after truncation).
	NumberOfMatches int  `json:"number_of_matches"`
	// Truncated reports whether more matches were found than were returned.
	Truncated       bool `json:"truncated"`
}
 84
 85const (
 86	GrepToolName        = "grep"
 87	maxGrepContentWidth = 500
 88)
 89
// grepDescription is embedded from grep.md at build time and passed to
// fantasy.NewAgentTool as the grep tool's description.
//
//go:embed grep.md
var grepDescription []byte
 92
 93// escapeRegexPattern escapes special regex characters so they're treated as literal characters
 94func escapeRegexPattern(pattern string) string {
 95	specialChars := []string{"\\", ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"}
 96	escaped := pattern
 97
 98	for _, char := range specialChars {
 99		escaped = strings.ReplaceAll(escaped, char, "\\"+char)
100	}
101
102	return escaped
103}
104
105func NewGrepTool(workingDir string, config config.ToolGrep) fantasy.AgentTool {
106	return fantasy.NewAgentTool(
107		GrepToolName,
108		string(grepDescription),
109		func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
110			if params.Pattern == "" {
111				return fantasy.NewTextErrorResponse("pattern is required"), nil
112			}
113
114			searchPattern := params.Pattern
115			if params.LiteralText {
116				searchPattern = escapeRegexPattern(params.Pattern)
117			}
118
119			searchPath := cmp.Or(params.Path, workingDir)
120
121			searchCtx, cancel := context.WithTimeout(ctx, config.GetTimeout())
122			defer cancel()
123
124			matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
125			if err != nil {
126				return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
127			}
128
129			var output strings.Builder
130			if len(matches) == 0 {
131				output.WriteString("No files found")
132			} else {
133				fmt.Fprintf(&output, "Found %d matches\n", len(matches))
134
135				currentFile := ""
136				for _, match := range matches {
137					if currentFile != match.path {
138						if currentFile != "" {
139							output.WriteString("\n")
140						}
141						currentFile = match.path
142						fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
143					}
144					if match.lineNum > 0 {
145						lineText := match.lineText
146						if len(lineText) > maxGrepContentWidth {
147							lineText = lineText[:maxGrepContentWidth] + "..."
148						}
149						if match.charNum > 0 {
150							fmt.Fprintf(&output, "  Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
151						} else {
152							fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, lineText)
153						}
154					} else {
155						fmt.Fprintf(&output, "  %s\n", match.path)
156					}
157				}
158
159				if truncated {
160					output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
161				}
162			}
163
164			return fantasy.WithResponseMetadata(
165				fantasy.NewTextResponse(output.String()),
166				GrepResponseMetadata{
167					NumberOfMatches: len(matches),
168					Truncated:       truncated,
169				},
170			), nil
171		})
172}
173
174func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
175	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
176	if err != nil {
177		matches, err = searchFilesWithRegex(pattern, rootPath, include)
178		if err != nil {
179			return nil, false, err
180		}
181	}
182
183	sort.Slice(matches, func(i, j int) bool {
184		return matches[i].modTime.After(matches[j].modTime)
185	})
186
187	truncated := len(matches) > limit
188	if truncated {
189		matches = matches[:limit]
190	}
191
192	return matches, truncated, nil
193}
194
195func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
196	cmd := getRgSearchCmd(ctx, pattern, path, include)
197	if cmd == nil {
198		return nil, fmt.Errorf("ripgrep not found in $PATH")
199	}
200
201	// Only add ignore files if they exist
202	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
203		ignorePath := filepath.Join(path, ignoreFile)
204		if _, err := os.Stat(ignorePath); err == nil {
205			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
206		}
207	}
208
209	output, err := cmd.Output()
210	if err != nil {
211		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
212			return []grepMatch{}, nil
213		}
214		return nil, err
215	}
216
217	var matches []grepMatch
218	for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
219		if len(line) == 0 {
220			continue
221		}
222		var match ripgrepMatch
223		if err := json.Unmarshal(line, &match); err != nil {
224			continue
225		}
226		if match.Type != "match" {
227			continue
228		}
229		for _, m := range match.Data.Submatches {
230			fi, err := os.Stat(match.Data.Path.Text)
231			if err != nil {
232				continue // Skip files we can't access
233			}
234			matches = append(matches, grepMatch{
235				path:     match.Data.Path.Text,
236				modTime:  fi.ModTime(),
237				lineNum:  match.Data.LineNumber,
238				charNum:  m.Start + 1, // ensure 1-based
239				lineText: strings.TrimSpace(match.Data.Lines.Text),
240			})
241			// only get the first match of each line
242			break
243		}
244	}
245	return matches, nil
246}
247
// ripgrepMatch decodes one line of `rg --json` output. Only the fields this
// file reads are declared; encoding/json ignores the rest of each event.
type ripgrepMatch struct {
	// Type is the event kind; only "match" events are consumed.
	Type string `json:"type"`
	Data struct {
		// LineNumber is the number of the matching line as reported by rg.
		LineNumber int `json:"line_number"`
		// Path.Text is the file the event refers to.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines.Text is the matching line's text as reported by rg.
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		// Submatches lists each match on the line; Start is its byte offset.
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
263
264func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
265	matches := []grepMatch{}
266
267	// Use cached regex compilation
268	regex, err := searchRegexCache.get(pattern)
269	if err != nil {
270		return nil, fmt.Errorf("invalid regex pattern: %w", err)
271	}
272
273	var includePattern *regexp.Regexp
274	if include != "" {
275		regexPattern := globToRegex(include)
276		includePattern, err = globRegexCache.get(regexPattern)
277		if err != nil {
278			return nil, fmt.Errorf("invalid include pattern: %w", err)
279		}
280	}
281
282	// Create walker with gitignore and crushignore support
283	walker := fsext.NewFastGlobWalker(rootPath)
284
285	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
286		if err != nil {
287			return nil // Skip errors
288		}
289
290		if info.IsDir() {
291			// Check if directory should be skipped
292			if walker.ShouldSkip(path) {
293				return filepath.SkipDir
294			}
295			return nil // Continue into directory
296		}
297
298		// Use walker's shouldSkip method for files
299		if walker.ShouldSkip(path) {
300			return nil
301		}
302
303		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
304		base := filepath.Base(path)
305		if base != "." && strings.HasPrefix(base, ".") {
306			return nil
307		}
308
309		if includePattern != nil && !includePattern.MatchString(path) {
310			return nil
311		}
312
313		match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
314		if err != nil {
315			return nil // Skip files we can't read
316		}
317
318		if match {
319			matches = append(matches, grepMatch{
320				path:     path,
321				modTime:  info.ModTime(),
322				lineNum:  lineNum,
323				charNum:  charNum,
324				lineText: lineText,
325			})
326
327			if len(matches) >= 200 {
328				return filepath.SkipAll
329			}
330		}
331
332		return nil
333	})
334	if err != nil {
335		return nil, err
336	}
337
338	return matches, nil
339}
340
341func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
342	// Only search text files.
343	if !isTextFile(filePath) {
344		return false, 0, 0, "", nil
345	}
346
347	file, err := os.Open(filePath)
348	if err != nil {
349		return false, 0, 0, "", err
350	}
351	defer file.Close()
352
353	scanner := bufio.NewScanner(file)
354	lineNum := 0
355	for scanner.Scan() {
356		lineNum++
357		line := scanner.Text()
358		if loc := pattern.FindStringIndex(line); loc != nil {
359			charNum := loc[0] + 1
360			return true, lineNum, charNum, line, nil
361		}
362	}
363
364	return false, 0, 0, "", scanner.Err()
365}
366
// isTextFile reports whether filePath looks like text, judged by the MIME
// type net/http sniffs from (at most) its first 512 bytes. Files that cannot
// be opened or read count as non-text.
func isTextFile(filePath string) bool {
	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	var head [512]byte
	n, readErr := f.Read(head[:])
	if readErr != nil && readErr != io.EOF {
		return false
	}

	mime := http.DetectContentType(head[:n])
	if strings.HasPrefix(mime, "text/") {
		return true
	}
	switch mime {
	case "application/json", "application/xml", "application/javascript", "application/x-sh":
		return true
	default:
		return false
	}
}
392
393func globToRegex(glob string) string {
394	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
395	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
396	regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
397
398	// Use pre-compiled regex instead of compiling each time
399	regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
400		inner := match[1 : len(match)-1]
401		return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
402	})
403
404	return regexPattern
405}