1package tools
2
3import (
4 "bufio"
5 "context"
6 _ "embed"
7 "encoding/json"
8 "fmt"
9 "io"
10 "net/http"
11 "os"
12 "os/exec"
13 "path/filepath"
14 "regexp"
15 "sort"
16 "strconv"
17 "strings"
18 "sync"
19 "time"
20
21 "github.com/charmbracelet/crush/internal/fsext"
22)
23
// regexCache memoizes compiled regular expressions keyed by their source
// pattern. All access to the map goes through mu, so a single cache may be
// shared between goroutines. The zero value is usable: the map is allocated
// on first insertion.
type regexCache struct {
	cache map[string]*regexp.Regexp
	mu    sync.RWMutex
}

// newRegexCache returns an empty regexCache.
func newRegexCache() *regexCache {
	return &regexCache{
		cache: make(map[string]*regexp.Regexp),
	}
}

// get returns the compiled form of pattern, compiling and storing it on first
// use. A pattern that fails to compile is not stored; the compile error is
// returned together with a nil *regexp.Regexp.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	// Fast path: a hit only needs the read lock.
	rc.mu.RLock()
	regex, exists := rc.cache[pattern]
	rc.mu.RUnlock()
	if exists {
		return regex, nil
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Look again under the write lock: another goroutine may have stored the
	// pattern between the RUnlock above and this Lock.
	if regex, exists := rc.cache[pattern]; exists {
		return regex, nil
	}

	regex, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}

	// Writing to a nil map panics, so allocate it for zero-value caches.
	if rc.cache == nil {
		rc.cache = make(map[string]*regexp.Regexp)
	}
	rc.cache[pattern] = regex
	return regex, nil
}
65
// Package-level regex state shared by all grep searches.
var (
	// searchRegexCache holds compiled content-search patterns.
	searchRegexCache = newRegexCache()
	// globRegexCache holds compiled include-filter patterns (globs already
	// converted to regex source).
	globRegexCache = newRegexCache()
	// globBraceRegex matches one non-nested brace group such as "{ts,tsx}".
	// MustCompile runs at package initialization, so the pattern is compiled
	// exactly once.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
73
// GrepParams is the JSON input of a grep tool call.
type GrepParams struct {
	// Pattern is the text to search for: a regular expression, or literal
	// text when LiteralText is set. It must not be empty.
	Pattern string `json:"pattern"`
	// Path is the directory to search. When empty, the tool's working
	// directory is used.
	Path string `json:"path"`
	// Include is an optional file glob (e.g. "*.js") restricting which files
	// are searched.
	Include string `json:"include"`
	// LiteralText makes the tool escape regex metacharacters in Pattern
	// before searching.
	LiteralText bool `json:"literal_text"`
}
80
// grepMatch is a single matching line reported by a search.
type grepMatch struct {
	// path is the file the match was found in.
	path string
	// modTime is that file's modification time; searchFiles orders results
	// by it, newest first.
	modTime time.Time
	// lineNum is the line number of the match. Run treats a value <= 0 as
	// "no line information" and prints only the path.
	lineNum int
	// lineText is the text of the matching line.
	lineText string
}
87
// GrepResponseMetadata is the structured metadata attached to a grep response.
type GrepResponseMetadata struct {
	// NumberOfMatches is the number of matches included in the response,
	// counted after any truncation.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether more matches were found than were returned.
	Truncated bool `json:"truncated"`
}
92
// grepTool implements the grep tool.
type grepTool struct {
	// workingDir is the directory searched when a call does not set a path.
	workingDir string
}
96
// GrepToolName is the identifier reported by the grep tool's Name and Info
// methods.
const GrepToolName = "grep"
98
// grepDescription is the tool description returned by Info. Its content is
// the file grep.md, embedded at build time.
//
//go:embed grep.md
var grepDescription []byte
101
102func NewGrepTool(workingDir string) BaseTool {
103 return &grepTool{
104 workingDir: workingDir,
105 }
106}
107
// Name returns the tool's identifier, GrepToolName.
func (g *grepTool) Name() string {
	return GrepToolName
}
111
112func (g *grepTool) Info() ToolInfo {
113 return ToolInfo{
114 Name: GrepToolName,
115 Description: string(grepDescription),
116 Parameters: map[string]any{
117 "pattern": map[string]any{
118 "type": "string",
119 "description": "The regex pattern to search for in file contents",
120 },
121 "path": map[string]any{
122 "type": "string",
123 "description": "The directory to search in. Defaults to the current working directory.",
124 },
125 "include": map[string]any{
126 "type": "string",
127 "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
128 },
129 "literal_text": map[string]any{
130 "type": "boolean",
131 "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
132 },
133 },
134 Required: []string{"pattern"},
135 }
136}
137
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) | [ ] { } ^ $) backslash-escaped, so the result
// matches pattern as literal text.
func escapeRegexPattern(pattern string) string {
	// regexp.QuoteMeta escapes exactly this character set in a single pass.
	return regexp.QuoteMeta(pattern)
}
149
150func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
151 var params GrepParams
152 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
153 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
154 }
155
156 if params.Pattern == "" {
157 return NewTextErrorResponse("pattern is required"), nil
158 }
159
160 // If literal_text is true, escape the pattern
161 searchPattern := params.Pattern
162 if params.LiteralText {
163 searchPattern = escapeRegexPattern(params.Pattern)
164 }
165
166 searchPath := params.Path
167 if searchPath == "" {
168 searchPath = g.workingDir
169 }
170
171 matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
172 if err != nil {
173 return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
174 }
175
176 var output strings.Builder
177 if len(matches) == 0 {
178 output.WriteString("No files found")
179 } else {
180 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
181
182 currentFile := ""
183 for _, match := range matches {
184 if currentFile != match.path {
185 if currentFile != "" {
186 output.WriteString("\n")
187 }
188 currentFile = match.path
189 fmt.Fprintf(&output, "%s:\n", match.path)
190 }
191 if match.lineNum > 0 {
192 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText)
193 } else {
194 fmt.Fprintf(&output, " %s\n", match.path)
195 }
196 }
197
198 if truncated {
199 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
200 }
201 }
202
203 return WithResponseMetadata(
204 NewTextResponse(output.String()),
205 GrepResponseMetadata{
206 NumberOfMatches: len(matches),
207 Truncated: truncated,
208 },
209 ), nil
210}
211
212func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
213 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
214 if err != nil {
215 matches, err = searchFilesWithRegex(pattern, rootPath, include)
216 if err != nil {
217 return nil, false, err
218 }
219 }
220
221 sort.Slice(matches, func(i, j int) bool {
222 return matches[i].modTime.After(matches[j].modTime)
223 })
224
225 truncated := len(matches) > limit
226 if truncated {
227 matches = matches[:limit]
228 }
229
230 return matches, truncated, nil
231}
232
233func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
234 cmd := getRgSearchCmd(ctx, pattern, path, include)
235 if cmd == nil {
236 return nil, fmt.Errorf("ripgrep not found in $PATH")
237 }
238
239 // Only add ignore files if they exist
240 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
241 ignorePath := filepath.Join(path, ignoreFile)
242 if _, err := os.Stat(ignorePath); err == nil {
243 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
244 }
245 }
246
247 output, err := cmd.Output()
248 if err != nil {
249 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
250 return []grepMatch{}, nil
251 }
252 return nil, err
253 }
254
255 lines := strings.Split(strings.TrimSpace(string(output)), "\n")
256 matches := make([]grepMatch, 0, len(lines))
257
258 for _, line := range lines {
259 if line == "" {
260 continue
261 }
262
263 // Parse ripgrep output using null separation
264 filePath, lineNumStr, lineText, ok := parseRipgrepLine(line)
265 if !ok {
266 continue
267 }
268
269 lineNum, err := strconv.Atoi(lineNumStr)
270 if err != nil {
271 continue
272 }
273
274 fileInfo, err := os.Stat(filePath)
275 if err != nil {
276 continue // Skip files we can't access
277 }
278
279 matches = append(matches, grepMatch{
280 path: filePath,
281 modTime: fileInfo.ModTime(),
282 lineNum: lineNum,
283 lineText: lineText,
284 })
285 }
286
287 return matches, nil
288}
289
// parseRipgrepLine splits one ripgrep output line of the form
// "<path>\x00<linenum>:<text>" into its three parts. The path ends at the NUL
// byte rather than at a colon, so paths that themselves contain colons (such
// as Windows drive letters) are parsed correctly.
//
// ok is false, and the other results are empty, when the NUL byte or the
// colon is missing or when the line-number field is not an integer. lineNum
// is returned as the original string.
func parseRipgrepLine(line string) (filePath, lineNum, lineText string, ok bool) {
	name, rest, found := strings.Cut(line, "\x00")
	if !found {
		return "", "", "", false
	}

	// Only the first colon separates the number from the text; the text may
	// contain further colons.
	number, text, found := strings.Cut(rest, ":")
	if !found {
		return "", "", "", false
	}

	if _, err := strconv.Atoi(number); err != nil {
		return "", "", "", false
	}

	return name, number, text, true
}
316
317func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
318 matches := []grepMatch{}
319
320 // Use cached regex compilation
321 regex, err := searchRegexCache.get(pattern)
322 if err != nil {
323 return nil, fmt.Errorf("invalid regex pattern: %w", err)
324 }
325
326 var includePattern *regexp.Regexp
327 if include != "" {
328 regexPattern := globToRegex(include)
329 includePattern, err = globRegexCache.get(regexPattern)
330 if err != nil {
331 return nil, fmt.Errorf("invalid include pattern: %w", err)
332 }
333 }
334
335 // Create walker with gitignore and crushignore support
336 walker := fsext.NewFastGlobWalker(rootPath)
337
338 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
339 if err != nil {
340 return nil // Skip errors
341 }
342
343 if info.IsDir() {
344 // Check if directory should be skipped
345 if walker.ShouldSkip(path) {
346 return filepath.SkipDir
347 }
348 return nil // Continue into directory
349 }
350
351 // Use walker's shouldSkip method for files
352 if walker.ShouldSkip(path) {
353 return nil
354 }
355
356 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
357 base := filepath.Base(path)
358 if base != "." && strings.HasPrefix(base, ".") {
359 return nil
360 }
361
362 if includePattern != nil && !includePattern.MatchString(path) {
363 return nil
364 }
365
366 match, lineNum, lineText, err := fileContainsPattern(path, regex)
367 if err != nil {
368 return nil // Skip files we can't read
369 }
370
371 if match {
372 matches = append(matches, grepMatch{
373 path: path,
374 modTime: info.ModTime(),
375 lineNum: lineNum,
376 lineText: lineText,
377 })
378
379 if len(matches) >= 200 {
380 return filepath.SkipAll
381 }
382 }
383
384 return nil
385 })
386 if err != nil {
387 return nil, err
388 }
389
390 return matches, nil
391}
392
// fileContainsPattern reports the first line of the file at filePath that
// matches pattern. It returns (true, line number, line text, nil) for a match;
// line numbers start at 1 and the text excludes the line terminator. Files
// that isTextFile rejects, and files without a match, yield
// (false, 0, "", nil). An error is returned when the file cannot be opened or
// read.
func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
	// Longest line the scanner accepts. The bufio.Scanner default of 64 KiB
	// (bufio.MaxScanTokenSize) makes Scan stop with bufio.ErrTooLong on
	// minified or generated files, which caused them to be skipped entirely.
	const maxLineBytes = 8 * 1024 * 1024

	// Only search text files.
	if !isTextFile(filePath) {
		return false, 0, "", nil
	}

	file, err := os.Open(filePath)
	if err != nil {
		return false, 0, "", err
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	// Start small and let the scanner grow its buffer up to maxLineBytes.
	scanner.Buffer(make([]byte, 0, 4096), maxLineBytes)

	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := scanner.Text()
		if pattern.MatchString(line) {
			return true, lineNum, line, nil
		}
	}

	// Scan returns false at EOF and on failure alike; Err tells them apart.
	return false, 0, "", scanner.Err()
}

// isTextFile reports whether the file at filePath looks like text, judged by
// sniffing the content type of its first 512 bytes. It returns false when the
// file cannot be opened or read.
func isTextFile(filePath string) bool {
	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	// http.DetectContentType considers at most the first 512 bytes.
	buffer := make([]byte, 512)
	n, err := file.Read(buffer)
	if err != nil && err != io.EOF {
		return false
	}

	// An empty file (n == 0) is sniffed as text/plain and counts as text.
	contentType := http.DetectContentType(buffer[:n])

	// NOTE(review): the sniffer appears to report text formats as "text/...",
	// so the application/* comparisons below may never be true; they are
	// kept as written.
	return strings.HasPrefix(contentType, "text/") ||
		contentType == "application/json" ||
		contentType == "application/xml" ||
		contentType == "application/javascript" ||
		contentType == "application/x-sh"
}
443
// globToRegex converts a file glob into regular-expression source.
//
// Translation rules:
//   - "*" becomes ".*" and "?" becomes "."
//   - a brace group "{a,b}" becomes the alternation "(a|b)"
//   - a bracket expression "[abc]" is kept as a character class, with a
//     leading "!" rewritten to "^"
//   - every other regex metacharacter is escaped and matches literally
//
// The result is anchored at the end only, so it matches when a suffix of the
// input matches the glob: "*.js" matches "/src/app.js" but not "/src/app.json".
// Braces or brackets that do not form a complete, non-empty group are treated
// as literal characters.
func globToRegex(glob string) string {
	// Regex metacharacters that carry no glob meaning at this point.
	const literalMeta = `\.+()|^$[]{}`

	var b strings.Builder
	b.Grow(len(glob) + 8)

	inBraces := false
	for i := 0; i < len(glob); i++ {
		c := glob[i]
		switch {
		case c == '*':
			b.WriteString(".*")
		case c == '?':
			b.WriteByte('.')
		case c == '[':
			end := strings.IndexByte(glob[i+1:], ']')
			class := ""
			if end > 0 {
				class = glob[i+1 : i+1+end]
			}
			negated := strings.HasPrefix(class, "!")
			if negated {
				class = class[1:]
			}
			if class == "" {
				// Unterminated or empty bracket expression: literal "[".
				b.WriteString(`\[`)
				continue
			}
			b.WriteByte('[')
			if negated {
				b.WriteByte('^')
			}
			b.WriteString(strings.ReplaceAll(class, `\`, `\\`))
			b.WriteByte(']')
			i += end + 1 // resume after the closing "]"
		case c == '{' && !inBraces && strings.IndexByte(glob[i+1:], '}') > 0:
			inBraces = true
			b.WriteByte('(')
		case c == '}' && inBraces:
			inBraces = false
			b.WriteByte(')')
		case c == ',' && inBraces:
			b.WriteByte('|')
		case strings.IndexByte(literalMeta, c) >= 0:
			b.WriteByte('\\')
			b.WriteByte(c)
		default:
			// Also covers the bytes of multi-byte UTF-8 characters, which
			// are copied through unchanged.
			b.WriteByte(c)
		}
	}

	b.WriteByte('$')
	return b.String()
}