1package tools
2
3import (
4 "bufio"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "io"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "regexp"
14 "sort"
15 "strconv"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/fsext"
21)
22
// regexCache memoizes compiled regular expressions keyed by their source
// text. Lookups take the read lock and compilation/insertion take the write
// lock, so one cache may be shared between goroutines. Nothing in this type
// evicts entries.
type regexCache struct {
	cache map[string]*regexp.Regexp
	mu    sync.RWMutex
}

// newRegexCache returns an empty cache that is ready for use.
func newRegexCache() *regexCache {
	rc := &regexCache{}
	rc.cache = make(map[string]*regexp.Regexp)
	return rc
}

// lookup returns the cached regex for pattern, if present, under the read lock.
func (rc *regexCache) lookup(pattern string) (*regexp.Regexp, bool) {
	rc.mu.RLock()
	defer rc.mu.RUnlock()
	re, ok := rc.cache[pattern]
	return re, ok
}

// get returns the compiled form of pattern, compiling and storing it on the
// first request. A compilation error is returned unchanged and nothing is
// cached for that pattern.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	if re, ok := rc.lookup(pattern); ok {
		return re, nil
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Another goroutine may have stored the entry between releasing the read
	// lock and acquiring the write lock; reuse it instead of compiling again.
	if re, ok := rc.cache[pattern]; ok {
		return re, nil
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	rc.cache[pattern] = re
	return re, nil
}

// Package-wide caches: one for search patterns, one for include globs after
// they have been converted to regular expressions.
var (
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
)

// globBraceRegex finds "{a,b,...}" alternation groups in a glob. It is
// compiled once here because globToRegex runs for every fallback search that
// has an include pattern.
var globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
72
type (
	// GrepParams is the JSON input accepted by the grep tool.
	GrepParams struct {
		// Pattern is the regex to search for (or literal text when
		// LiteralText is set). It must be non-empty.
		Pattern string `json:"pattern"`
		// Path is the directory to search; empty means the tool's working
		// directory.
		Path string `json:"path"`
		// Include is an optional glob restricting which files are searched.
		Include string `json:"include"`
		// LiteralText requests that regex metacharacters in Pattern be
		// escaped.
		LiteralText bool `json:"literal_text"`
	}

	// grepMatch is a single matching line found by a search.
	grepMatch struct {
		// path is the file containing the match.
		path string
		// modTime is the file's modification time; results are ordered
		// newest first.
		modTime time.Time
		// lineNum is the 1-based line number of the match.
		lineNum int
		// lineText is the content of the matching line.
		lineText string
	}

	// GrepResponseMetadata is attached to every grep tool response.
	GrepResponseMetadata struct {
		// NumberOfMatches counts the matches actually returned (after
		// truncation).
		NumberOfMatches int `json:"number_of_matches"`
		// Truncated reports whether further matches were dropped.
		Truncated bool `json:"truncated"`
	}
)
91
92type grepTool struct {
93 workingDir string
94}
95
96const GrepToolName = "grep"
97
98//go:embed grep.md
99var grepDescription []byte
100
101func NewGrepTool(workingDir string) BaseTool {
102 return &grepTool{
103 workingDir: workingDir,
104 }
105}
106
107func (g *grepTool) Name() string {
108 return GrepToolName
109}
110
111func (g *grepTool) Info() ToolInfo {
112 return ToolInfo{
113 Name: GrepToolName,
114 Description: string(grepDescription),
115 Parameters: map[string]any{
116 "pattern": map[string]any{
117 "type": "string",
118 "description": "The regex pattern to search for in file contents",
119 },
120 "path": map[string]any{
121 "type": "string",
122 "description": "The directory to search in. Defaults to the current working directory.",
123 },
124 "include": map[string]any{
125 "type": "string",
126 "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
127 },
128 "literal_text": map[string]any{
129 "type": "boolean",
130 "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
131 },
132 },
133 Required: []string{"pattern"},
134 }
135}
136
// escapeRegexPattern escapes every regular-expression metacharacter in
// pattern so that it matches as literal text. Each of \ . + * ? ( ) [ ] { } ^ $ |
// is prefixed with a backslash. This is exactly the set regexp.QuoteMeta
// handles, so the standard-library function is used. It does the work in one
// pass instead of one strings.ReplaceAll per character.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
148
149func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
150 var params GrepParams
151 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
152 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
153 }
154
155 if params.Pattern == "" {
156 return NewTextErrorResponse("pattern is required"), nil
157 }
158
159 // If literal_text is true, escape the pattern
160 searchPattern := params.Pattern
161 if params.LiteralText {
162 searchPattern = escapeRegexPattern(params.Pattern)
163 }
164
165 searchPath := params.Path
166 if searchPath == "" {
167 searchPath = g.workingDir
168 }
169
170 matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
171 if err != nil {
172 return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
173 }
174
175 var output strings.Builder
176 if len(matches) == 0 {
177 output.WriteString("No files found")
178 } else {
179 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
180
181 currentFile := ""
182 for _, match := range matches {
183 if currentFile != match.path {
184 if currentFile != "" {
185 output.WriteString("\n")
186 }
187 currentFile = match.path
188 fmt.Fprintf(&output, "%s:\n", match.path)
189 }
190 if match.lineNum > 0 {
191 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText)
192 } else {
193 fmt.Fprintf(&output, " %s\n", match.path)
194 }
195 }
196
197 if truncated {
198 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
199 }
200 }
201
202 return WithResponseMetadata(
203 NewTextResponse(output.String()),
204 GrepResponseMetadata{
205 NumberOfMatches: len(matches),
206 Truncated: truncated,
207 },
208 ), nil
209}
210
211func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
212 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
213 if err != nil {
214 matches, err = searchFilesWithRegex(pattern, rootPath, include)
215 if err != nil {
216 return nil, false, err
217 }
218 }
219
220 sort.Slice(matches, func(i, j int) bool {
221 return matches[i].modTime.After(matches[j].modTime)
222 })
223
224 truncated := len(matches) > limit
225 if truncated {
226 matches = matches[:limit]
227 }
228
229 return matches, truncated, nil
230}
231
232func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
233 cmd := getRgSearchCmd(ctx, pattern, path, include)
234 if cmd == nil {
235 return nil, fmt.Errorf("ripgrep not found in $PATH")
236 }
237
238 // Only add ignore files if they exist
239 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
240 ignorePath := filepath.Join(path, ignoreFile)
241 if _, err := os.Stat(ignorePath); err == nil {
242 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
243 }
244 }
245
246 output, err := cmd.Output()
247 if err != nil {
248 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
249 return []grepMatch{}, nil
250 }
251 return nil, err
252 }
253
254 lines := strings.Split(strings.TrimSpace(string(output)), "\n")
255 matches := make([]grepMatch, 0, len(lines))
256
257 for _, line := range lines {
258 if line == "" {
259 continue
260 }
261
262 // Parse ripgrep output using null separation
263 filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
264 if !ok {
265 continue
266 }
267
268 lineNum, err := strconv.Atoi(lineNumStr)
269 if err != nil {
270 continue
271 }
272
273 fileInfo, err := os.Stat(filePath)
274 if err != nil {
275 continue // Skip files we can't access
276 }
277
278 matches = append(matches, grepMatch{
279 path: filePath,
280 modTime: fileInfo.ModTime(),
281 lineNum: lineNum,
282 lineText: lineText,
283 })
284 }
285
286 return matches, nil
287}
288
// parseRipgrepLine splits one ripgrep output line of the form
// "path\x00linenum:text" into its parts. The path is delimited by the first
// NUL byte rather than a colon, so paths that contain colons (such as
// Windows drive letters) survive intact. ok is false, and every other result
// empty, when the NUL or the following colon is missing or when linenum is
// not an integer.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	name, rest, found := strings.Cut(line, "\x00")
	if !found {
		return "", "", "", false
	}

	num, text, found := strings.Cut(rest, ":")
	if !found {
		return "", "", "", false
	}

	if _, err := strconv.Atoi(num); err != nil {
		return "", "", "", false
	}

	return name, num, text, true
}
315
316func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
317 matches := []grepMatch{}
318
319 // Use cached regex compilation
320 regex, err := searchRegexCache.get(pattern)
321 if err != nil {
322 return nil, fmt.Errorf("invalid regex pattern: %w", err)
323 }
324
325 var includePattern *regexp.Regexp
326 if include != "" {
327 regexPattern := globToRegex(include)
328 includePattern, err = globRegexCache.get(regexPattern)
329 if err != nil {
330 return nil, fmt.Errorf("invalid include pattern: %w", err)
331 }
332 }
333
334 // Create walker with gitignore and crushignore support
335 walker := fsext.NewFastGlobWalker(rootPath)
336
337 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
338 if err != nil {
339 return nil // Skip errors
340 }
341
342 if info.IsDir() {
343 // Check if directory should be skipped
344 if walker.ShouldSkip(path) {
345 return filepath.SkipDir
346 }
347 return nil // Continue into directory
348 }
349
350 // Use walker's shouldSkip method for files
351 if walker.ShouldSkip(path) {
352 return nil
353 }
354
355 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
356 base := filepath.Base(path)
357 if base != "." && strings.HasPrefix(base, ".") {
358 return nil
359 }
360
361 if includePattern != nil && !includePattern.MatchString(path) {
362 return nil
363 }
364
365 match, lineNum, lineText, err := fileContainsPattern(path, regex)
366 if err != nil {
367 return nil // Skip files we can't read
368 }
369
370 if match {
371 matches = append(matches, grepMatch{
372 path: path,
373 modTime: info.ModTime(),
374 lineNum: lineNum,
375 lineText: lineText,
376 })
377
378 if len(matches) >= 200 {
379 return filepath.SkipAll
380 }
381 }
382
383 return nil
384 })
385 if err != nil {
386 return nil, err
387 }
388
389 return matches, nil
390}
391
// fileContainsPattern reports whether any line of filePath matches pattern.
// On the first matching line it returns true, that line's 1-based number and
// its text. Files that isBinaryFile classifies as binary are reported as
// non-matching without being scanned. A non-nil error means the file could
// not be opened or fully read.
func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
	// Quick binary file detection
	if isBinaryFile(filePath) {
		return false, 0, "", nil
	}

	file, err := os.Open(filePath)
	if err != nil {
		return false, 0, "", err
	}
	defer file.Close()

	// By default bufio.Scanner rejects lines over 64 KiB with ErrTooLong.
	// The caller treats any error as "skip this file", so one long line
	// (e.g. a minified bundle) hid every match in the file. Allow far longer
	// lines; the buffer starts small and only grows when needed.
	const maxLineBytes = 16 * 1024 * 1024
	scanner := bufio.NewScanner(file)
	scanner.Buffer(make([]byte, 0, 4096), maxLineBytes)

	lineNum := 0
	for scanner.Scan() {
		lineNum++
		if line := scanner.Text(); pattern.MatchString(line) {
			return true, lineNum, line, nil
		}
	}

	return false, 0, "", scanner.Err()
}

// binaryExts lists lower-case file extensions that are always treated as
// binary without reading the file.
var binaryExts = map[string]struct{}{
	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
}

// isBinaryFile guesses whether filePath holds binary data. It first checks
// the (case-insensitive) extension against binaryExts, then looks for a NUL
// byte in the first 512 bytes. Files that cannot be opened or read are
// reported as non-binary so the caller's own open/read surfaces the error.
func isBinaryFile(filePath string) bool {
	ext := strings.ToLower(filepath.Ext(filePath))
	if _, isBinary := binaryExts[ext]; isBinary {
		return true
	}

	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	// io.ReadFull keeps reading until the buffer is full or the file ends.
	// A single Read may return fewer bytes than are available.
	var head [512]byte
	n, err := io.ReadFull(file, head[:])
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return false
	}

	for _, b := range head[:n] {
		if b == 0 {
			return true
		}
	}
	return false
}
457
458func globToRegex(glob string) string {
459 regexPattern := strings.ReplaceAll(glob, ".", "\\.")
460 regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
461 regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
462
463 // Use pre-compiled regex instead of compiling each time
464 regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
465 inner := match[1 : len(match)-1]
466 return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
467 })
468
469 return regexPattern
470}