1package tools
2
3import (
4 "bufio"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "io"
10 "os"
11 "os/exec"
12 "path/filepath"
13 "regexp"
14 "sort"
15 "strconv"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/fsext"
21)
22
// regexCache is a concurrency-safe memo of compiled regular expressions keyed
// by their source text. Only successfully compiled patterns are stored, and
// entries are never evicted, so the map grows with every distinct pattern.
type regexCache struct {
	cache map[string]*regexp.Regexp
	mu    sync.RWMutex
}

// newRegexCache returns an empty cache that is ready for use.
func newRegexCache() *regexCache {
	rc := new(regexCache)
	rc.cache = make(map[string]*regexp.Regexp)
	return rc
}

// lookup returns the cached regexp for pattern while holding only the read lock.
func (rc *regexCache) lookup(pattern string) (*regexp.Regexp, bool) {
	rc.mu.RLock()
	defer rc.mu.RUnlock()
	re, ok := rc.cache[pattern]
	return re, ok
}

// get returns the compiled form of pattern, compiling and storing it on first
// use. A compilation error is returned unchanged and nothing is cached for it.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	if re, ok := rc.lookup(pattern); ok {
		return re, nil
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Re-check under the write lock: another goroutine may have stored the
	// pattern between the read-locked lookup and acquiring Lock.
	if re, ok := rc.cache[pattern]; ok {
		return re, nil
	}

	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	rc.cache[pattern] = re
	return re, nil
}

// Package-wide caches shared by all grep calls: searchRegexCache holds search
// patterns and globRegexCache holds regexes derived from include globs.
var (
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()

	// globBraceRegex matches a "{a,b,...}" alternation group in a glob; its
	// submatch is the text between the braces. It is compiled once because
	// globToRegex applies it on every call.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
72
type (
	// GrepParams is the JSON input of a grep tool call.
	GrepParams struct {
		// Pattern is the regular expression to search for; it is required.
		Pattern string `json:"pattern"`
		// Path is the directory to search; empty means the tool's working directory.
		Path string `json:"path"`
		// Include is an optional glob restricting which files are searched.
		Include string `json:"include"`
		// LiteralText, when true, escapes Pattern so it matches as plain text.
		LiteralText bool `json:"literal_text"`
	}

	// grepMatch is one matching line together with its file's modification
	// time, which is used to order results.
	grepMatch struct {
		path     string
		modTime  time.Time
		lineNum  int
		lineText string
	}

	// GrepResponseMetadata is attached to every grep tool response.
	GrepResponseMetadata struct {
		NumberOfMatches int  `json:"number_of_matches"`
		Truncated       bool `json:"truncated"`
	}

	// grepTool implements the grep tool; workingDir is the default search root.
	grepTool struct {
		workingDir string
	}
)

const (
	// GrepToolName is the name the tool reports through Name and Info.
	GrepToolName = "grep"
	// maxGrepContentWidth is the longest line text, in bytes, printed in a
	// result before it is cut and suffixed with "...".
	maxGrepContentWidth = 500
)
100
// grepDescription holds the contents of grep.md, embedded at build time; Info
// returns it as the tool's Description.
//
//go:embed grep.md
var grepDescription []byte
103
104func NewGrepTool(workingDir string) BaseTool {
105 return &grepTool{
106 workingDir: workingDir,
107 }
108}
109
110func (g *grepTool) Name() string {
111 return GrepToolName
112}
113
114func (g *grepTool) Info() ToolInfo {
115 return ToolInfo{
116 Name: GrepToolName,
117 Description: string(grepDescription),
118 Parameters: map[string]any{
119 "pattern": map[string]any{
120 "type": "string",
121 "description": "The regex pattern to search for in file contents",
122 },
123 "path": map[string]any{
124 "type": "string",
125 "description": "The directory to search in. Defaults to the current working directory.",
126 },
127 "include": map[string]any{
128 "type": "string",
129 "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
130 },
131 "literal_text": map[string]any{
132 "type": "boolean",
133 "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
134 },
135 },
136 Required: []string{"pattern"},
137 }
138}
139
// escapeRegexPattern escapes every regular-expression metacharacter in
// pattern (\ . + * ? ( ) [ ] { } ^ $ |) with a backslash, so that the result
// matches pattern as literal text.
//
// regexp.QuoteMeta escapes exactly this set in a single pass, instead of
// rescanning the string once per metacharacter.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
151
152func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
153 var params GrepParams
154 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
155 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
156 }
157
158 if params.Pattern == "" {
159 return NewTextErrorResponse("pattern is required"), nil
160 }
161
162 // If literal_text is true, escape the pattern
163 searchPattern := params.Pattern
164 if params.LiteralText {
165 searchPattern = escapeRegexPattern(params.Pattern)
166 }
167
168 searchPath := params.Path
169 if searchPath == "" {
170 searchPath = g.workingDir
171 }
172
173 matches, truncated, err := searchFiles(ctx, searchWithRipgrep, searchPattern, searchPath, params.Include, 100)
174 if err != nil {
175 return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
176 }
177
178 var output strings.Builder
179 if len(matches) == 0 {
180 output.WriteString("No files found")
181 } else {
182 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
183
184 currentFile := ""
185 for _, match := range matches {
186 if currentFile != match.path {
187 if currentFile != "" {
188 output.WriteString("\n")
189 }
190 currentFile = match.path
191 fmt.Fprintf(&output, "%s:\n", match.path)
192 }
193 if match.lineNum > 0 {
194 lineText := match.lineText
195 if len(lineText) > maxGrepContentWidth {
196 lineText = lineText[:maxGrepContentWidth] + "..."
197 }
198 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText)
199 } else {
200 fmt.Fprintf(&output, " %s\n", match.path)
201 }
202 }
203
204 if truncated {
205 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
206 }
207 }
208
209 return WithResponseMetadata(
210 NewTextResponse(output.String()),
211 GrepResponseMetadata{
212 NumberOfMatches: len(matches),
213 Truncated: truncated,
214 },
215 ), nil
216}
217
218func searchFiles(ctx context.Context, ripGrepSearch searchWithRipgrapFn, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
219 matches, err := ripGrepSearch(ctx, getRgSearchCmd, pattern, rootPath, include)
220 if err != nil {
221 matches, err = searchFilesWithRegex(pattern, rootPath, include)
222 if err != nil {
223 return nil, false, err
224 }
225 }
226
227 sort.Slice(matches, func(i, j int) bool {
228 return matches[i].modTime.After(matches[j].modTime)
229 })
230
231 truncated := len(matches) > limit
232 if truncated {
233 matches = matches[:limit]
234 }
235
236 return matches, truncated, nil
237}
238
// searchWithRipgrapFn is the signature of a ripgrep-backed search function such
// as searchWithRipgrep. searchFiles receives one as a parameter instead of
// calling searchWithRipgrep directly, so the implementation can be substituted
// (presumably for tests; confirm against callers). The "Ripgrap" spelling is
// part of the type's name.
type searchWithRipgrapFn func(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error)
240
241// NOTE(tauraamui): ideally I would want to not pass in the search specific args here but will leave for now
242func searchWithRipgrep(ctx context.Context, rgSearchCmd resolveRgSearchCmd, pattern, path, include string) ([]grepMatch, error) {
243 cmd := rgSearchCmd(ctx, pattern, path, include)
244 if cmd == nil {
245 return nil, fmt.Errorf("ripgrep not found in $PATH")
246 }
247
248 // Only add ignore files if they exist
249 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
250 ignorePath := filepath.Join(path, ignoreFile)
251 if _, err := os.Stat(ignorePath); err == nil {
252 cmd.AddArgs("--ignore-file", ignorePath)
253 }
254 }
255
256 output, err := cmd.Output()
257 if err != nil {
258 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
259 return []grepMatch{}, nil
260 }
261 return nil, err
262 }
263
264 lines := strings.Split(strings.TrimSpace(string(output)), "\n")
265 matches := make([]grepMatch, 0, len(lines))
266
267 for _, line := range lines {
268 if line == "" {
269 continue
270 }
271
272 // Parse ripgrep output using null separation
273 filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
274 if !ok {
275 continue
276 }
277
278 lineNum, err := strconv.Atoi(lineNumStr)
279 if err != nil {
280 continue
281 }
282
283 fileInfo, err := os.Stat(filePath)
284 if err != nil {
285 continue // Skip files we can't access
286 }
287
288 matches = append(matches, grepMatch{
289 path: filePath,
290 modTime: fileInfo.ModTime(),
291 lineNum: lineNum,
292 lineText: lineText,
293 })
294 }
295
296 return matches, nil
297}
298
// parseRipgrepLine splits one line of ripgrep's NUL-separated output,
// "<path>\x00<line number>:<text>", into its parts. Ending the path at the NUL
// rather than at a ':' keeps Windows paths such as "C:\src\a.go" intact.
//
// ok is false, and every other result is empty, when the NUL or the following
// ':' is missing or the line number is not an integer. lineNum is returned as
// the validated decimal text; lineText may itself contain ':'.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	path, rest, found := strings.Cut(line, "\x00")
	if !found {
		return "", "", "", false
	}

	num, text, found := strings.Cut(rest, ":")
	if !found {
		return "", "", "", false
	}

	if _, err := strconv.Atoi(num); err != nil {
		return "", "", "", false
	}

	return path, num, text, true
}
325
326func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
327 matches := []grepMatch{}
328
329 // Use cached regex compilation
330 regex, err := searchRegexCache.get(pattern)
331 if err != nil {
332 return nil, fmt.Errorf("invalid regex pattern: %w", err)
333 }
334
335 var includePattern *regexp.Regexp
336 if include != "" {
337 regexPattern := globToRegex(include)
338 includePattern, err = globRegexCache.get(regexPattern)
339 if err != nil {
340 return nil, fmt.Errorf("invalid include pattern: %w", err)
341 }
342 }
343
344 // Create walker with gitignore and crushignore support
345 walker := fsext.NewFastGlobWalker(rootPath)
346
347 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
348 if err != nil {
349 return nil // Skip errors
350 }
351
352 if info.IsDir() {
353 // Check if directory should be skipped
354 if walker.ShouldSkip(path) {
355 return filepath.SkipDir
356 }
357 return nil // Continue into directory
358 }
359
360 // Use walker's shouldSkip method for files
361 if walker.ShouldSkip(path) {
362 return nil
363 }
364
365 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
366 base := filepath.Base(path)
367 if base != "." && strings.HasPrefix(base, ".") {
368 return nil
369 }
370
371 if includePattern != nil && !includePattern.MatchString(path) {
372 return nil
373 }
374
375 match, lineNum, lineText, err := fileContainsPattern(path, regex)
376 if err != nil {
377 return nil // Skip files we can't read
378 }
379
380 if match {
381 matches = append(matches, grepMatch{
382 path: path,
383 modTime: info.ModTime(),
384 lineNum: lineNum,
385 lineText: lineText,
386 })
387
388 if len(matches) >= 200 {
389 return filepath.SkipAll
390 }
391 }
392
393 return nil
394 })
395 if err != nil {
396 return nil, err
397 }
398
399 return matches, nil
400}
401
402func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
403 // Quick binary file detection
404 if isBinaryFile(filePath) {
405 return false, 0, "", nil
406 }
407
408 file, err := os.Open(filePath)
409 if err != nil {
410 return false, 0, "", err
411 }
412 defer file.Close()
413
414 scanner := bufio.NewScanner(file)
415 lineNum := 0
416 for scanner.Scan() {
417 lineNum++
418 line := scanner.Text()
419 if pattern.MatchString(line) {
420 return true, lineNum, line, nil
421 }
422 }
423
424 return false, 0, "", scanner.Err()
425}
426
// binaryExts is the set of lower-case file extensions that isBinaryFile treats
// as binary without opening the file.
var binaryExts = map[string]struct{}{
	// executables, libraries, object and raw binary files
	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
	// archives
	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
	// images
	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
	// documents
	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
	// audio and video
	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
}

// isBinaryFile cheaply guesses whether filePath is binary. It first checks the
// extension, case-insensitively, against binaryExts, and otherwise looks for a
// NUL byte in the first 512 bytes of the file. Files that cannot be opened or
// read are reported as not binary; callers such as fileContainsPattern open
// the file themselves and surface the error there.
func isBinaryFile(filePath string) bool {
	if _, known := binaryExts[strings.ToLower(filepath.Ext(filePath))]; known {
		return true
	}

	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	var head [512]byte
	n, err := f.Read(head[:])
	if err != nil && err != io.EOF {
		return false
	}

	// A NUL byte in the leading chunk is taken as a sign of binary content.
	for _, b := range head[:n] {
		if b == 0 {
			return true
		}
	}
	return false
}
467
468func globToRegex(glob string) string {
469 regexPattern := strings.ReplaceAll(glob, ".", "\\.")
470 regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
471 regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
472
473 // Use pre-compiled regex instead of compiling each time
474 regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
475 inner := match[1 : len(match)-1]
476 return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
477 })
478
479 return regexPattern
480}