1package tools
  2
  3import (
  4	"bufio"
  5	"context"
  6	_ "embed"
  7	"encoding/json"
  8	"fmt"
  9	"io"
 10	"os"
 11	"os/exec"
 12	"path/filepath"
 13	"regexp"
 14	"sort"
 15	"strconv"
 16	"strings"
 17	"sync"
 18	"time"
 19
 20	"github.com/charmbracelet/crush/internal/fsext"
 21	"github.com/charmbracelet/crush/internal/proto"
 22)
 23
 24// regexCache provides thread-safe caching of compiled regex patterns
 25type regexCache struct {
 26	cache map[string]*regexp.Regexp
 27	mu    sync.RWMutex
 28}
 29
 30// newRegexCache creates a new regex cache
 31func newRegexCache() *regexCache {
 32	return ®exCache{
 33		cache: make(map[string]*regexp.Regexp),
 34	}
 35}
 36
 37// get retrieves a compiled regex from cache or compiles and caches it
 38func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
 39	// Try to get from cache first (read lock)
 40	rc.mu.RLock()
 41	if regex, exists := rc.cache[pattern]; exists {
 42		rc.mu.RUnlock()
 43		return regex, nil
 44	}
 45	rc.mu.RUnlock()
 46
 47	// Compile the regex (write lock)
 48	rc.mu.Lock()
 49	defer rc.mu.Unlock()
 50
 51	// Double-check in case another goroutine compiled it while we waited
 52	if regex, exists := rc.cache[pattern]; exists {
 53		return regex, nil
 54	}
 55
 56	// Compile and cache the regex
 57	regex, err := regexp.Compile(pattern)
 58	if err != nil {
 59		return nil, err
 60	}
 61
 62	rc.cache[pattern] = regex
 63	return regex, nil
 64}
 65
 66// Global regex cache instances
 67var (
 68	searchRegexCache = newRegexCache()
 69	globRegexCache   = newRegexCache()
 70	// Pre-compiled regex for glob conversion (used frequently)
 71	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
 72)
 73
 74type grepMatch struct {
 75	path     string
 76	modTime  time.Time
 77	lineNum  int
 78	lineText string
 79}
 80
 81type (
 82	GrepParams           = proto.GrepParams
 83	GrepResponseMetadata = proto.GrepResponseMetadata
 84)
 85
 86type grepTool struct {
 87	workingDir string
 88}
 89
 90const GrepToolName = proto.GrepToolName
 91
 92//go:embed grep.md
 93var grepDescription []byte
 94
 95func NewGrepTool(workingDir string) BaseTool {
 96	return &grepTool{
 97		workingDir: workingDir,
 98	}
 99}
100
101func (g *grepTool) Name() string {
102	return GrepToolName
103}
104
105func (g *grepTool) Info() ToolInfo {
106	return ToolInfo{
107		Name:        GrepToolName,
108		Description: string(grepDescription),
109		Parameters: map[string]any{
110			"pattern": map[string]any{
111				"type":        "string",
112				"description": "The regex pattern to search for in file contents",
113			},
114			"path": map[string]any{
115				"type":        "string",
116				"description": "The directory to search in. Defaults to the current working directory.",
117			},
118			"include": map[string]any{
119				"type":        "string",
120				"description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
121			},
122			"literal_text": map[string]any{
123				"type":        "boolean",
124				"description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
125			},
126		},
127		Required: []string{"pattern"},
128	}
129}
130
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) [ ] { } ^ $ |) backslash-escaped, so the result
// matches the original text literally.
func escapeRegexPattern(pattern string) string {
	// regexp.QuoteMeta escapes exactly this character set in a single pass,
	// replacing the previous chain of per-character replacements.
	return regexp.QuoteMeta(pattern)
}
142
143func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
144	var params GrepParams
145	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
146		return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
147	}
148
149	if params.Pattern == "" {
150		return NewTextErrorResponse("pattern is required"), nil
151	}
152
153	// If literal_text is true, escape the pattern
154	searchPattern := params.Pattern
155	if params.LiteralText {
156		searchPattern = escapeRegexPattern(params.Pattern)
157	}
158
159	searchPath := params.Path
160	if searchPath == "" {
161		searchPath = g.workingDir
162	}
163
164	matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
165	if err != nil {
166		return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
167	}
168
169	var output strings.Builder
170	if len(matches) == 0 {
171		output.WriteString("No files found")
172	} else {
173		fmt.Fprintf(&output, "Found %d matches\n", len(matches))
174
175		currentFile := ""
176		for _, match := range matches {
177			if currentFile != match.path {
178				if currentFile != "" {
179					output.WriteString("\n")
180				}
181				currentFile = match.path
182				fmt.Fprintf(&output, "%s:\n", match.path)
183			}
184			if match.lineNum > 0 {
185				fmt.Fprintf(&output, "  Line %d: %s\n", match.lineNum, match.lineText)
186			} else {
187				fmt.Fprintf(&output, "  %s\n", match.path)
188			}
189		}
190
191		if truncated {
192			output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
193		}
194	}
195
196	return WithResponseMetadata(
197		NewTextResponse(output.String()),
198		GrepResponseMetadata{
199			NumberOfMatches: len(matches),
200			Truncated:       truncated,
201		},
202	), nil
203}
204
205func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
206	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
207	if err != nil {
208		matches, err = searchFilesWithRegex(pattern, rootPath, include)
209		if err != nil {
210			return nil, false, err
211		}
212	}
213
214	sort.Slice(matches, func(i, j int) bool {
215		return matches[i].modTime.After(matches[j].modTime)
216	})
217
218	truncated := len(matches) > limit
219	if truncated {
220		matches = matches[:limit]
221	}
222
223	return matches, truncated, nil
224}
225
226func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
227	cmd := getRgSearchCmd(ctx, pattern, path, include)
228	if cmd == nil {
229		return nil, fmt.Errorf("ripgrep not found in $PATH")
230	}
231
232	// Only add ignore files if they exist
233	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
234		ignorePath := filepath.Join(path, ignoreFile)
235		if _, err := os.Stat(ignorePath); err == nil {
236			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
237		}
238	}
239
240	output, err := cmd.Output()
241	if err != nil {
242		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
243			return []grepMatch{}, nil
244		}
245		return nil, err
246	}
247
248	lines := strings.Split(strings.TrimSpace(string(output)), "\n")
249	matches := make([]grepMatch, 0, len(lines))
250
251	for _, line := range lines {
252		if line == "" {
253			continue
254		}
255
256		// Parse ripgrep output using null separation
257		filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
258		if !ok {
259			continue
260		}
261
262		lineNum, err := strconv.Atoi(lineNumStr)
263		if err != nil {
264			continue
265		}
266
267		fileInfo, err := os.Stat(filePath)
268		if err != nil {
269			continue // Skip files we can't access
270		}
271
272		matches = append(matches, grepMatch{
273			path:     filePath,
274			modTime:  fileInfo.ModTime(),
275			lineNum:  lineNum,
276			lineText: lineText,
277		})
278	}
279
280	return matches, nil
281}
282
// parseRipgrepLine splits one ripgrep output line of the form
// "<path>\x00<line number>:<text>". The NUL byte delimits the path, so paths
// that themselves contain colons (such as Windows drive prefixes) parse
// correctly. ok is false, and the other results are empty, when the NUL or the
// colon is missing or when the line number is not an integer.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	name, rest, found := strings.Cut(line, "\x00")
	if !found {
		return "", "", "", false
	}

	// Only the first colon separates the number from the text; any further
	// colons belong to the matched line.
	number, text, found := strings.Cut(rest, ":")
	if !found {
		return "", "", "", false
	}

	if _, err := strconv.Atoi(number); err != nil {
		return "", "", "", false
	}

	return name, number, text, true
}
309
310func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
311	matches := []grepMatch{}
312
313	// Use cached regex compilation
314	regex, err := searchRegexCache.get(pattern)
315	if err != nil {
316		return nil, fmt.Errorf("invalid regex pattern: %w", err)
317	}
318
319	var includePattern *regexp.Regexp
320	if include != "" {
321		regexPattern := globToRegex(include)
322		includePattern, err = globRegexCache.get(regexPattern)
323		if err != nil {
324			return nil, fmt.Errorf("invalid include pattern: %w", err)
325		}
326	}
327
328	// Create walker with gitignore and crushignore support
329	walker := fsext.NewFastGlobWalker(rootPath)
330
331	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
332		if err != nil {
333			return nil // Skip errors
334		}
335
336		if info.IsDir() {
337			// Check if directory should be skipped
338			if walker.ShouldSkip(path) {
339				return filepath.SkipDir
340			}
341			return nil // Continue into directory
342		}
343
344		// Use walker's shouldSkip method for files
345		if walker.ShouldSkip(path) {
346			return nil
347		}
348
349		// Skip hidden files (starting with a dot) to match ripgrep's default behavior
350		base := filepath.Base(path)
351		if base != "." && strings.HasPrefix(base, ".") {
352			return nil
353		}
354
355		if includePattern != nil && !includePattern.MatchString(path) {
356			return nil
357		}
358
359		match, lineNum, lineText, err := fileContainsPattern(path, regex)
360		if err != nil {
361			return nil // Skip files we can't read
362		}
363
364		if match {
365			matches = append(matches, grepMatch{
366				path:     path,
367				modTime:  info.ModTime(),
368				lineNum:  lineNum,
369				lineText: lineText,
370			})
371
372			if len(matches) >= 200 {
373				return filepath.SkipAll
374			}
375		}
376
377		return nil
378	})
379	if err != nil {
380		return nil, err
381	}
382
383	return matches, nil
384}
385
386func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
387	// Quick binary file detection
388	if isBinaryFile(filePath) {
389		return false, 0, "", nil
390	}
391
392	file, err := os.Open(filePath)
393	if err != nil {
394		return false, 0, "", err
395	}
396	defer file.Close()
397
398	scanner := bufio.NewScanner(file)
399	lineNum := 0
400	for scanner.Scan() {
401		lineNum++
402		line := scanner.Text()
403		if pattern.MatchString(line) {
404			return true, lineNum, line, nil
405		}
406	}
407
408	return false, 0, "", scanner.Err()
409}
410
// binaryExts is the set of lower-case file extensions that isBinaryFile treats
// as binary without looking at the file's content.
var binaryExts = map[string]struct{}{
	// Executables, libraries and object code.
	".exe":   {},
	".dll":   {},
	".so":    {},
	".dylib": {},
	".bin":   {},
	".obj":   {},
	".o":     {},
	".a":     {},
	// Archives.
	".zip": {},
	".tar": {},
	".gz":  {},
	".bz2": {},
	// Images.
	".jpg":  {},
	".jpeg": {},
	".png":  {},
	".gif":  {},
	// Documents.
	".pdf":  {},
	".doc":  {},
	".docx": {},
	".xls":  {},
	// Audio and video.
	".mp3": {},
	".mp4": {},
	".avi": {},
	".mov": {},
}

// isBinaryFile is a cheap heuristic for "this file is not text". It returns
// true when the extension (compared case-insensitively) is listed in
// binaryExts, or when a NUL byte occurs within the first 512 bytes of the
// file. Files that cannot be opened or read are reported as not binary, which
// leaves error handling to the caller's own open.
func isBinaryFile(filePath string) bool {
	// Extension lookup first: it needs no I/O.
	if _, known := binaryExts[strings.ToLower(filepath.Ext(filePath))]; known {
		return true
	}

	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	// Sniff only the head of the file; io.EOF just means the file is short.
	head := make([]byte, 512)
	n, err := f.Read(head)
	if err != nil && err != io.EOF {
		return false
	}

	// A NUL byte is common in binary files.
	for _, b := range head[:n] {
		if b == 0 {
			return true
		}
	}
	return false
}
451
// globToRegex converts a file glob such as "*.js" or "*.{ts,tsx}" into the
// source of a regular expression that can be matched against a file path.
//
// Translation rules:
//   - "*" matches any run of characters and "?" matches a single character;
//   - "{a,b}" matches either alternative (one level, no nesting);
//   - "/" matches either path separator, "/" or "\";
//   - every other character matches itself, including regex metacharacters.
//
// The result is anchored: the glob must match from the start of the path or
// from just after a path separator, through to the end of the path. Hence
// "*.js" accepts "src/app.js" but not "package.json", and "main.go" accepts
// "cmd/main.go" but not "cmd/xmain.go".
func globToRegex(glob string) string {
	var sb strings.Builder
	sb.Grow(2*len(glob) + 16)

	// Anchor at the beginning of the path or of any path component.
	sb.WriteString(`(?:^|[/\\])`)

	inBraces := false
	for i, r := range glob {
		switch {
		case r == '*':
			sb.WriteString(".*")
		case r == '?':
			sb.WriteByte('.')
		case r == '/':
			sb.WriteString(`[/\\]`)
		case r == '{' && !inBraces && strings.IndexByte(glob[i+1:], '}') > 0:
			// Open a group only if a closing brace follows with at least one
			// character in between; otherwise the brace is a literal.
			sb.WriteString("(?:")
			inBraces = true
		case r == ',' && inBraces:
			sb.WriteByte('|')
		case r == '}' && inBraces:
			sb.WriteByte(')')
			inBraces = false
		default:
			sb.WriteString(regexp.QuoteMeta(string(r)))
		}
	}

	sb.WriteByte('$')
	return sb.String()
}