1package tools
2
3import (
4 "bufio"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "io"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "regexp"
14 "sort"
15 "strconv"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/fsext"
21 "github.com/charmbracelet/crush/internal/proto"
22)
23
// regexCache is a cache of compiled regular expressions keyed by their
// pattern source. All access to the map goes through mu, so a single
// regexCache may be shared between goroutines.
type regexCache struct {
	cache map[string]*regexp.Regexp
	mu    sync.RWMutex // guards cache
}

// newRegexCache returns an empty regexCache that is ready for use.
func newRegexCache() *regexCache {
	return &regexCache{
		cache: make(map[string]*regexp.Regexp),
	}
}

// get returns the compiled form of pattern, compiling and storing it the
// first time the pattern is seen. Later calls with the same pattern return
// the same *regexp.Regexp. A pattern that fails to compile is not stored;
// the compile error is returned together with a nil regex.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	// Fast path: look the pattern up under the shared (read) lock.
	rc.mu.RLock()
	regex, exists := rc.cache[pattern]
	rc.mu.RUnlock()
	if exists {
		return regex, nil
	}

	// Slow path: take the exclusive lock to compile and insert.
	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Look again: another goroutine may have inserted the pattern between
	// releasing the read lock and acquiring the write lock.
	if regex, exists := rc.cache[pattern]; exists {
		return regex, nil
	}

	regex, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}

	rc.cache[pattern] = regex
	return regex, nil
}
65
// Global regex cache instances.
//
// searchRegexCache holds the compiled content-search patterns and
// globRegexCache holds the compiled regexes derived from include globs.
// globBraceRegex matches a non-empty brace group such as "{ts,tsx}".
var (
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// Pre-compiled regex for glob conversion (used frequently)
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
73
// grepMatch is a single search hit: one matching line in one file.
type grepMatch struct {
	path     string    // path of the file containing the match
	modTime  time.Time // modification time of that file; results are sorted on it
	lineNum  int       // line number of the match; Run prints only the path when it is not positive
	lineText string    // text of the matching line
}
80
// Aliases for the grep tool's parameter and response-metadata types, which
// are declared in the proto package.
type (
	GrepParams           = proto.GrepParams
	GrepResponseMetadata = proto.GrepResponseMetadata
)
85
// grepTool is the grep tool implementation returned by NewGrepTool.
type grepTool struct {
	workingDir string // directory searched when a call does not supply a path
}
89
// GrepToolName is the grep tool's name as defined by the proto package; it
// is the value reported by Name and Info.
const GrepToolName = proto.GrepToolName
91
// grepDescription holds the contents of grep.md, embedded at build time.
// Info returns it as the tool's description.
//
//go:embed grep.md
var grepDescription []byte
94
95func NewGrepTool(workingDir string) BaseTool {
96 return &grepTool{
97 workingDir: workingDir,
98 }
99}
100
// Name returns the tool's name, GrepToolName.
func (g *grepTool) Name() string {
	return GrepToolName
}
104
105func (g *grepTool) Info() ToolInfo {
106 return ToolInfo{
107 Name: GrepToolName,
108 Description: string(grepDescription),
109 Parameters: map[string]any{
110 "pattern": map[string]any{
111 "type": "string",
112 "description": "The regex pattern to search for in file contents",
113 },
114 "path": map[string]any{
115 "type": "string",
116 "description": "The directory to search in. Defaults to the current working directory.",
117 },
118 "include": map[string]any{
119 "type": "string",
120 "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
121 },
122 "literal_text": map[string]any{
123 "type": "boolean",
124 "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
125 },
126 },
127 Required: []string{"pattern"},
128 }
129}
130
131// escapeRegexPattern escapes special regex characters so they're treated as literal characters
132func escapeRegexPattern(pattern string) string {
133 specialChars := []string{"\\", ".", "+", "*", "?", "(", ")", "[", "]", "{", "}", "^", "$", "|"}
134 escaped := pattern
135
136 for _, char := range specialChars {
137 escaped = strings.ReplaceAll(escaped, char, "\\"+char)
138 }
139
140 return escaped
141}
142
143func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
144 var params GrepParams
145 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
146 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
147 }
148
149 if params.Pattern == "" {
150 return NewTextErrorResponse("pattern is required"), nil
151 }
152
153 // If literal_text is true, escape the pattern
154 searchPattern := params.Pattern
155 if params.LiteralText {
156 searchPattern = escapeRegexPattern(params.Pattern)
157 }
158
159 searchPath := params.Path
160 if searchPath == "" {
161 searchPath = g.workingDir
162 }
163
164 matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
165 if err != nil {
166 return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
167 }
168
169 var output strings.Builder
170 if len(matches) == 0 {
171 output.WriteString("No files found")
172 } else {
173 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
174
175 currentFile := ""
176 for _, match := range matches {
177 if currentFile != match.path {
178 if currentFile != "" {
179 output.WriteString("\n")
180 }
181 currentFile = match.path
182 fmt.Fprintf(&output, "%s:\n", match.path)
183 }
184 if match.lineNum > 0 {
185 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText)
186 } else {
187 fmt.Fprintf(&output, " %s\n", match.path)
188 }
189 }
190
191 if truncated {
192 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
193 }
194 }
195
196 return WithResponseMetadata(
197 NewTextResponse(output.String()),
198 GrepResponseMetadata{
199 NumberOfMatches: len(matches),
200 Truncated: truncated,
201 },
202 ), nil
203}
204
205func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
206 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
207 if err != nil {
208 matches, err = searchFilesWithRegex(pattern, rootPath, include)
209 if err != nil {
210 return nil, false, err
211 }
212 }
213
214 sort.Slice(matches, func(i, j int) bool {
215 return matches[i].modTime.After(matches[j].modTime)
216 })
217
218 truncated := len(matches) > limit
219 if truncated {
220 matches = matches[:limit]
221 }
222
223 return matches, truncated, nil
224}
225
226func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
227 cmd := getRgSearchCmd(ctx, pattern, path, include)
228 if cmd == nil {
229 return nil, fmt.Errorf("ripgrep not found in $PATH")
230 }
231
232 // Only add ignore files if they exist
233 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
234 ignorePath := filepath.Join(path, ignoreFile)
235 if _, err := os.Stat(ignorePath); err == nil {
236 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
237 }
238 }
239
240 output, err := cmd.Output()
241 if err != nil {
242 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
243 return []grepMatch{}, nil
244 }
245 return nil, err
246 }
247
248 lines := strings.Split(strings.TrimSpace(string(output)), "\n")
249 matches := make([]grepMatch, 0, len(lines))
250
251 for _, line := range lines {
252 if line == "" {
253 continue
254 }
255
256 // Parse ripgrep output using null separation
257 filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
258 if !ok {
259 continue
260 }
261
262 lineNum, err := strconv.Atoi(lineNumStr)
263 if err != nil {
264 continue
265 }
266
267 fileInfo, err := os.Stat(filePath)
268 if err != nil {
269 continue // Skip files we can't access
270 }
271
272 matches = append(matches, grepMatch{
273 path: filePath,
274 modTime: fileInfo.ModTime(),
275 lineNum: lineNum,
276 lineText: lineText,
277 })
278 }
279
280 return matches, nil
281}
282
283// parseRipgrepLine parses ripgrep output with null separation to handle Windows paths
284func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
285 // Split on null byte first to separate filename from rest
286 parts := strings.SplitN(line, "\x00", 2)
287 if len(parts) != 2 {
288 return "", "", "", false
289 }
290
291 filePath = parts[0]
292 remainder := parts[1]
293
294 // Now split the remainder on first colon: "linenum:content"
295 colonIndex := strings.Index(remainder, ":")
296 if colonIndex == -1 {
297 return "", "", "", false
298 }
299
300 lineNumStr := remainder[:colonIndex]
301 lineText = remainder[colonIndex+1:]
302
303 if _, err := strconv.Atoi(lineNumStr); err != nil {
304 return "", "", "", false
305 }
306
307 return filePath, lineNumStr, lineText, true
308}
309
310func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
311 matches := []grepMatch{}
312
313 // Use cached regex compilation
314 regex, err := searchRegexCache.get(pattern)
315 if err != nil {
316 return nil, fmt.Errorf("invalid regex pattern: %w", err)
317 }
318
319 var includePattern *regexp.Regexp
320 if include != "" {
321 regexPattern := globToRegex(include)
322 includePattern, err = globRegexCache.get(regexPattern)
323 if err != nil {
324 return nil, fmt.Errorf("invalid include pattern: %w", err)
325 }
326 }
327
328 // Create walker with gitignore and crushignore support
329 walker := fsext.NewFastGlobWalker(rootPath)
330
331 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
332 if err != nil {
333 return nil // Skip errors
334 }
335
336 if info.IsDir() {
337 // Check if directory should be skipped
338 if walker.ShouldSkip(path) {
339 return filepath.SkipDir
340 }
341 return nil // Continue into directory
342 }
343
344 // Use walker's shouldSkip method for files
345 if walker.ShouldSkip(path) {
346 return nil
347 }
348
349 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
350 base := filepath.Base(path)
351 if base != "." && strings.HasPrefix(base, ".") {
352 return nil
353 }
354
355 if includePattern != nil && !includePattern.MatchString(path) {
356 return nil
357 }
358
359 match, lineNum, lineText, err := fileContainsPattern(path, regex)
360 if err != nil {
361 return nil // Skip files we can't read
362 }
363
364 if match {
365 matches = append(matches, grepMatch{
366 path: path,
367 modTime: info.ModTime(),
368 lineNum: lineNum,
369 lineText: lineText,
370 })
371
372 if len(matches) >= 200 {
373 return filepath.SkipAll
374 }
375 }
376
377 return nil
378 })
379 if err != nil {
380 return nil, err
381 }
382
383 return matches, nil
384}
385
386func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
387 // Quick binary file detection
388 if isBinaryFile(filePath) {
389 return false, 0, "", nil
390 }
391
392 file, err := os.Open(filePath)
393 if err != nil {
394 return false, 0, "", err
395 }
396 defer file.Close()
397
398 scanner := bufio.NewScanner(file)
399 lineNum := 0
400 for scanner.Scan() {
401 lineNum++
402 line := scanner.Text()
403 if pattern.MatchString(line) {
404 return true, lineNum, line, nil
405 }
406 }
407
408 return false, 0, "", scanner.Err()
409}
410
// binaryExts is the set of lower-case file extensions that isBinaryFile
// treats as binary without looking at the file's contents.
var binaryExts = map[string]struct{}{
	// executables, libraries and object code
	".exe":   {},
	".dll":   {},
	".so":    {},
	".dylib": {},
	".bin":   {},
	".obj":   {},
	".o":     {},
	".a":     {},
	// archives
	".zip": {},
	".tar": {},
	".gz":  {},
	".bz2": {},
	// images
	".jpg":  {},
	".jpeg": {},
	".png":  {},
	".gif":  {},
	// documents
	".pdf":  {},
	".doc":  {},
	".docx": {},
	".xls":  {},
	// audio and video
	".mp3": {},
	".mp4": {},
	".avi": {},
	".mov": {},
}

// isBinaryFile is a cheap heuristic for spotting binary files. A file counts
// as binary when its extension (compared case-insensitively) is listed in
// binaryExts, or when a NUL byte occurs in the bytes returned by a single
// read of up to 512 bytes from its start. A file that cannot be opened or
// read is reported as not binary.
func isBinaryFile(filePath string) bool {
	// Cheapest test first: the extension alone decides for known types.
	if _, known := binaryExts[strings.ToLower(filepath.Ext(filePath))]; known {
		return true
	}

	f, err := os.Open(filePath)
	if err != nil {
		return false // the caller will hit the same error when it opens the file
	}
	defer f.Close()

	// Sniff the head of the file; an empty file yields io.EOF with n == 0.
	head := make([]byte, 512)
	n, err := f.Read(head)
	if err != nil && err != io.EOF {
		return false
	}

	// NUL bytes are common in binary files.
	for _, b := range head[:n] {
		if b == 0 {
			return true
		}
	}
	return false
}
451
// globMetaReplacer translates glob wildcards to their regex equivalents and
// escapes the regex operators that have no special meaning in a glob, all in
// a single pass. Brackets, braces and backslashes are left untouched.
var globMetaReplacer = strings.NewReplacer(
	".", `\.`,
	"*", ".*",
	"?", ".",
	"+", `\+`,
	"(", `\(`,
	")", `\)`,
	"^", `\^`,
	"$", `\$`,
	"|", `\|`,
)

// globToRegex converts a file glob such as "*.js" or "*.{ts,tsx}" into the
// source of an equivalent regular expression.
//
// "*" becomes ".*", "?" becomes ".", and a non-empty brace group "{a,b}"
// becomes the alternation "(a|b)". The characters . + ( ) ^ $ | are escaped
// so that they match themselves. Bracket expressions and backslashes are
// passed through unchanged, as are "{}" and an unterminated "{". The result
// is not anchored.
func globToRegex(glob string) string {
	rest := globMetaReplacer.Replace(glob)

	var sb strings.Builder
	sb.Grow(len(rest))
	for {
		open := strings.IndexByte(rest, '{')
		if open < 0 {
			break
		}
		// Length of the group body, i.e. the distance to the closing brace.
		length := strings.IndexByte(rest[open+1:], '}')
		if length < 0 {
			break // unterminated group: everything left is literal
		}
		if length == 0 {
			// "{}" is not a group; keep the "{" and continue after it.
			sb.WriteString(rest[:open+1])
			rest = rest[open+1:]
			continue
		}

		body := rest[open+1 : open+1+length]
		sb.WriteString(rest[:open])
		sb.WriteByte('(')
		sb.WriteString(strings.ReplaceAll(body, ",", "|"))
		sb.WriteByte(')')
		rest = rest[open+1+length+1:]
	}
	sb.WriteString(rest)

	return sb.String()
}