grep.go

  1package tools
  2
  3import (
  4	"bufio"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"net/http"
 11	"os"
 12	"os/exec"
 13	"path/filepath"
 14	"regexp"
 15	"sort"
 16	"strconv"
 17	"strings"
 18	"sync"
 19	"time"
 20
 21	"github.com/charmbracelet/crush/internal/fsext"
 22)
 23
 24// regexCache provides thread-safe caching of compiled regex patterns
 25type regexCache struct {
 26	cache map[string]*regexp.Regexp
 27	mu    sync.RWMutex
 28}
 29
 30// newRegexCache creates a new regex cache
 31func newRegexCache() *regexCache {
 32	return &regexCache{
 33		cache: make(map[string]*regexp.Regexp),
 34	}
 35}
 36
 37// get retrieves a compiled regex from cache or compiles and caches it
 38func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 39	// Try to get from cache first (read lock)
 40	rc.mu.RLock()
 41	if regex, exists := rc.cache[pattern]; exists {
 42		rc.mu.RUnlock()
 43		return regex, nil
 44	}
 45	rc.mu.RUnlock()
 46
 47	// Compile the regex (write lock)
 48	rc.mu.Lock()
 49	defer rc.mu.Unlock()
 50
 51	// Double-check in case another goroutine compiled it while we waited
 52	if regex, exists := rc.cache[pattern]; exists {
 53		return regex, nil
 54	}
 55
 56	// Compile and cache the regex
 57	regex, err := regexp.Compile(pattern)
 58	if err != nil {
 59		return nil, err
 60	}
 61
 62	rc.cache[pattern] = regex
 63	return regex, nil
 64}
 65
// Package-level regex state shared by the grep helpers:
//   - searchRegexCache caches the compiled content-search patterns.
//   - globRegexCache caches the compiled include patterns produced by globToRegex.
//   - globBraceRegex matches a brace group such as "{ts,tsx}" in a glob.
var (
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// Compiled once at package initialization and reused by globToRegex.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
 73
// GrepParams is the JSON input accepted by the grep tool.
//
// Pattern is the expression to search for and must be non-empty. Path is the
// directory to search; when empty, the tool's working directory is used.
// Include is an optional file glob (for example "*.js") restricting which
// files are searched. LiteralText makes Pattern match as literal text by
// escaping its regex metacharacters before the search.
type GrepParams struct {
	Pattern     string `json:"pattern"`
	Path        string `json:"path"`
	Include     string `json:"include"`
	LiteralText bool   `json:"literal_text"`
}
 80
// grepMatch is a single search hit: the file it was found in, that file's
// modification time (used to order results), and the number and text of the
// matching line. Run treats a lineNum of 0 or less as "no line information"
// and prints only the path.
type grepMatch struct {
	path     string
	modTime  time.Time
	lineNum  int
	lineText string
}
 87
// GrepResponseMetadata is attached to every grep response. NumberOfMatches is
// the number of matches included in the response (after truncation), and
// Truncated reports whether additional matches were dropped.
type GrepResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
 92
// grepTool implements the grep tool. workingDir is the directory searched
// when a call does not supply a path.
type grepTool struct {
	workingDir string
}
 96
// GrepToolName is the name of the grep tool, as reported by Name and Info.
const GrepToolName = "grep"
 98
// grepDescription holds the contents of grep.md, embedded at build time. Info
// returns it as the tool's description.
//
//go:embed grep.md
var grepDescription []byte
101
102func NewGrepTool(workingDir string) BaseTool {
103	return &grepTool{
104		workingDir: workingDir,
105	}
106}
107
// Name returns GrepToolName, the name of this tool.
func (g *grepTool) Name() string {
	return GrepToolName
}
111
112func (g *grepTool) Info() ToolInfo {
113	return ToolInfo{
114		Name:        GrepToolName,
115		Description: string(grepDescription),
116		Parameters: map[string]any{
117			"pattern": map[string]any{
118				"type":        "string",
119				"description": "The regex pattern to search for in file contents",
120			},
121			"path": map[string]any{
122				"type":        "string",
123				"description": "The directory to search in. Defaults to the current working directory.",
124			},
125			"include": map[string]any{
126				"type":        "string",
127				"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
128			},
129			"literal_text": map[string]any{
130				"type":        "boolean",
131				"description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
132			},
133		},
134		Required: []string{"pattern"},
135	}
136}
137
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter backslash-escaped, so that the result matches pattern as
// literal text.
func escapeRegexPattern(pattern string) string {
	// regexp.QuoteMeta escapes \ . + * ? ( ) | [ ] { } ^ $ in a single pass,
	// which is the full set of characters that needs escaping here.
	return regexp.QuoteMeta(pattern)
}
149
150func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
151	var params GrepParams
152	if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
153		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
154	}
155
156	if params.Pattern == "" {
157		return NewTextErrorResponse("pattern is required"), nil
158	}
159
160	// If literal_text is true, escape the pattern
161	searchPattern := params.Pattern
162	if params.LiteralText {
163		searchPattern = escapeRegexPattern(params.Pattern)
164	}
165
166	searchPath := params.Path
167	if searchPath == "" {
168		searchPath = g.workingDir
169	}
170
171	matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
172	if err != nil {
173		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
174	}
175
176	var output strings.Builder
177	if len(matches) == 0 {
178		output.WriteString("No files found")
179	} else {
180		fmt.Fprintf(&output, "Found %d matches\n", len(matches))
181
182		currentFile := ""
183		for _, match := range matches {
184			if currentFile != match.path {
185				if currentFile != "" {
186					output.WriteString("\n")
187				}
188				currentFile = match.path
189				fmt.Fprintf(&output, "%s:\n", match.path)
190			}
191			if match.lineNum > 0 {
192				fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, match.lineText)
193			} else {
194				fmt.Fprintf(&output, "  %s\n", match.path)
195			}
196		}
197
198		if truncated {
199			output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
200		}
201	}
202
203	return WithResponseMetadata(
204		NewTextResponse(output.String()),
205		GrepResponseMetadata{
206			NumberOfMatches: len(matches),
207			Truncated:       truncated,
208		},
209	), nil
210}
211
212func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
213	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
214	if err != nil {
215		matches, err = searchFilesWithRegex(pattern, rootPath, include)
216		if err != nil {
217			return nil, false, err
218		}
219	}
220
221	sort.Slice(matches, func(i, j int) bool {
222		return matches[i].modTime.After(matches[j].modTime)
223	})
224
225	truncated := len(matches) > limit
226	if truncated {
227		matches = matches[:limit]
228	}
229
230	return matches, truncated, nil
231}
232
233func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
234	cmd := getRgSearchCmd(ctx, pattern, path, include)
235	if cmd == nil {
236		return nil, fmt.Errorf("ripgrep not found in $PATH")
237	}
238
239	// Only add ignore files if they exist
240	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
241		ignorePath := filepath.Join(path, ignoreFile)
242		if _, err := os.Stat(ignorePath); err == nil {
243			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
244		}
245	}
246
247	output, err := cmd.Output()
248	if err != nil {
249		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
250			return []grepMatch{}, nil
251		}
252		return nil, err
253	}
254
255	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
256	matches := make([]grepMatch, 0, len(lines))
257
258	for _, line := range lines {
259		if line == "" {
260			continue
261		}
262
263		// Parse ripgrep output using null separation
264		filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
265		if !ok {
266			continue
267		}
268
269		lineNum, err := strconv.Atoi(lineNumStr)
270		if err != nil {
271			continue
272		}
273
274		fileInfo, err := os.Stat(filePath)
275		if err != nil {
276			continue // Skip files we can't access
277		}
278
279		matches = append(matches, grepMatch{
280			path:     filePath,
281			modTime:  fileInfo.ModTime(),
282			lineNum:  lineNum,
283			lineText: lineText,
284		})
285	}
286
287	return matches, nil
288}
289
// parseRipgrepLine splits one line of ripgrep output of the form
// "<path>\x00<line number>:<text>" into its three parts. The path is
// delimited by a NUL byte rather than a colon, so paths that themselves
// contain colons (such as Windows drive paths) are parsed correctly.
//
// ok is false, and the other results are empty, when the NUL byte or the
// colon is missing or the line-number field is not an integer.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	name, rest, found := strings.Cut(line, "\x00")
	if !found {
		return "", "", "", false
	}

	// rest is "<line number>:<text>"; only the first colon is a separator.
	number, text, found := strings.Cut(rest, ":")
	if !found {
		return "", "", "", false
	}

	if _, err := strconv.Atoi(number); err != nil {
		return "", "", "", false
	}

	return name, number, text, true
}
316
317func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
318	matches := []grepMatch{}
319
320	// Use cached regex compilation
321	regex, err := searchRegexCache.get(pattern)
322	if err != nil {
323		return nil, fmt.Errorf("invalid regex pattern: %w", err)
324	}
325
326	var includePattern *regexp.Regexp
327	if include != "" {
328		regexPattern := globToRegex(include)
329		includePattern, err = globRegexCache.get(regexPattern)
330		if err != nil {
331			return nil, fmt.Errorf("invalid include pattern: %w", err)
332		}
333	}
334
335	// Create walker with gitignore and crushignore support
336	walker := fsext.NewFastGlobWalker(rootPath)
337
338	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
339		if err != nil {
340			return nil // Skip errors
341		}
342
343		if info.IsDir() {
344			// Check if directory should be skipped
345			if walker.ShouldSkip(path) {
346				return filepath.SkipDir
347			}
348			return nil // Continue into directory
349		}
350
351		// Use walker's shouldSkip method for files
352		if walker.ShouldSkip(path) {
353			return nil
354		}
355
356		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
357		base := filepath.Base(path)
358		if base != "." && strings.HasPrefix(base, ".") {
359			return nil
360		}
361
362		if includePattern != nil && !includePattern.MatchString(path) {
363			return nil
364		}
365
366		match, lineNum, lineText, err := fileContainsPattern(path, regex)
367		if err != nil {
368			return nil // Skip files we can't read
369		}
370
371		if match {
372			matches = append(matches, grepMatch{
373				path:     path,
374				modTime:  info.ModTime(),
375				lineNum:  lineNum,
376				lineText: lineText,
377			})
378
379			if len(matches) >= 200 {
380				return filepath.SkipAll
381			}
382		}
383
384		return nil
385	})
386	if err != nil {
387		return nil, err
388	}
389
390	return matches, nil
391}
392
393func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
394	// Only search text files.
395	if !isTextFile(filePath) {
396		return false, 0, "", nil
397	}
398
399	file, err := os.Open(filePath)
400	if err != nil {
401		return false, 0, "", err
402	}
403	defer file.Close()
404
405	scanner := bufio.NewScanner(file)
406	lineNum := 0
407	for scanner.Scan() {
408		lineNum++
409		line := scanner.Text()
410		if pattern.MatchString(line) {
411			return true, lineNum, line, nil
412		}
413	}
414
415	return false, 0, "", scanner.Err()
416}
417
// isTextFile reports whether the file at filePath looks like text. It sniffs
// up to the first 512 bytes with http.DetectContentType and accepts any
// "text/" type plus a short list of textual application types. Files that
// cannot be opened or read are reported as not text.
func isTextFile(filePath string) bool {
	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	// 512 bytes is the most that content sniffing looks at.
	var head [512]byte
	n, err := f.Read(head[:])
	if err != nil && err != io.EOF {
		return false
	}

	switch mime := http.DetectContentType(head[:n]); mime {
	case "application/json",
		"application/xml",
		"application/javascript",
		"application/x-sh":
		return true
	default:
		return strings.HasPrefix(mime, "text/")
	}
}
443
// globToRegex converts a file glob such as "*.js" or "*.{ts,tsx}" into the
// source of a regular expression that is meant to be matched against a file
// path.
//
// Translation rules:
//   - "*" matches any run of characters (including path separators);
//   - "?" matches exactly one character;
//   - "{a,b}" matches any one of the comma-separated alternatives;
//   - every other character matches itself, with regex metacharacters
//     escaped so that globs like "*.c++" or "a(b).txt" stay valid.
//
// The expression is anchored at the end of the path and at the start of a
// path segment, so "*.js" does not match "app.json" and "main.go" does not
// match "domain.go".
func globToRegex(glob string) string {
	var b strings.Builder
	b.Grow(len(glob) + 16)

	// The match has to begin at the start of the path or right after a
	// path separator ('/' or '\').
	b.WriteString(`(^|[/\\])`)

	inBrace := false
	for i, r := range glob {
		switch {
		case r == '*':
			b.WriteString(".*")
		case r == '?':
			b.WriteByte('.')
		case r == '{' && !inBrace && strings.IndexByte(glob[i+1:], '}') > 0:
			// Open a group only for a non-empty, properly closed brace
			// expression; any other '{' is a literal.
			inBrace = true
			b.WriteByte('(')
		case r == ',' && inBrace:
			b.WriteByte('|')
		case r == '}' && inBrace:
			inBrace = false
			b.WriteByte(')')
		default:
			b.WriteString(regexp.QuoteMeta(string(r)))
		}
	}

	// Require the glob to match through the end of the path.
	b.WriteByte('$')
	return b.String()
}
456}