1package tools
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 _ "embed"
8 "encoding/json"
9 "fmt"
10 "io"
11 "net/http"
12 "os"
13 "os/exec"
14 "path/filepath"
15 "regexp"
16 "sort"
17 "strings"
18 "sync"
19 "time"
20
21 "github.com/charmbracelet/crush/internal/fsext"
22)
23
// regexCache is a concurrency-safe memo of compiled regular expressions keyed
// by their source pattern. get never evicts entries.
type regexCache struct {
	cache map[string]*regexp.Regexp
	mu    sync.RWMutex
}

// newRegexCache returns an empty, ready-to-use regexCache.
func newRegexCache() *regexCache {
	return &regexCache{
		cache: make(map[string]*regexp.Regexp),
	}
}

// get returns the compiled form of pattern, compiling and caching it on the
// first request. Compilation errors are returned to the caller and are not
// cached, so an invalid pattern is recompiled (and fails again) on every call.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	rc.mu.RLock()
	regex, ok := rc.cache[pattern]
	rc.mu.RUnlock()
	if ok {
		return regex, nil
	}

	// Compile without holding any lock so a slow compilation does not block
	// lookups of unrelated patterns.
	compiled, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Another goroutine may have stored the same pattern while we compiled;
	// keep the first entry so every caller shares one *regexp.Regexp.
	if existing, ok := rc.cache[pattern]; ok {
		return existing, nil
	}
	rc.cache[pattern] = compiled
	return compiled, nil
}

// Package-level caches shared by all grep invocations.
var (
	// searchRegexCache holds compiled search patterns.
	searchRegexCache = newRegexCache()
	// globRegexCache holds compiled regexes derived from include globs.
	globRegexCache = newRegexCache()
	// globBraceRegex matches a non-nested brace group such as "{ts,tsx}" and
	// captures its contents.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
73
// GrepParams are the JSON arguments accepted by the grep tool.
type GrepParams struct {
	// Pattern is the regular expression to search for; Run rejects an empty value.
	Pattern string `json:"pattern"`
	// Path is the directory to search; empty means the tool's working directory.
	Path string `json:"path"`
	// Include is an optional file glob (e.g. "*.js") limiting which files are searched.
	Include string `json:"include"`
	// LiteralText makes Run escape Pattern so it is matched as plain text.
	LiteralText bool `json:"literal_text"`
}

// grepMatch is a single search hit produced by either search backend.
type grepMatch struct {
	// path is the file containing the match.
	path string
	// modTime is the file's modification time; searchFiles sorts newest first.
	modTime time.Time
	// lineNum is the 1-based line number; Run prints only the path when it is 0.
	lineNum int
	// charNum is the 1-based byte offset of the match in the line; Run omits it when 0.
	charNum int
	// lineText is the text of the matching line.
	lineText string
}

// GrepResponseMetadata is attached to successful grep responses.
type GrepResponseMetadata struct {
	// NumberOfMatches is how many matches the response lists (after truncation).
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether more matches were found than were listed.
	Truncated bool `json:"truncated"`
}
93
// grepTool is the grep tool implementation returned by NewGrepTool.
type grepTool struct {
	// workingDir is searched when a call does not supply a path.
	workingDir string
}

// GrepToolName is the name reported by the grep tool.
const GrepToolName = "grep"

// grepDescription holds the contents of grep.md, embedded at build time and
// returned as the tool description by Info.
//
//go:embed grep.md
var grepDescription []byte
102
103func NewGrepTool(workingDir string) BaseTool {
104 return &grepTool{
105 workingDir: workingDir,
106 }
107}
108
109func (g *grepTool) Name() string {
110 return GrepToolName
111}
112
113func (g *grepTool) Info() ToolInfo {
114 return ToolInfo{
115 Name: GrepToolName,
116 Description: string(grepDescription),
117 Parameters: map[string]any{
118 "pattern": map[string]any{
119 "type": "string",
120 "description": "The regex pattern to search for in file contents",
121 },
122 "path": map[string]any{
123 "type": "string",
124 "description": "The directory to search in. Defaults to the current working directory.",
125 },
126 "include": map[string]any{
127 "type": "string",
128 "description": "File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")",
129 },
130 "literal_text": map[string]any{
131 "type": "boolean",
132 "description": "If true, the pattern will be treated as literal text with special regex characters escaped. Default is false.",
133 },
134 },
135 Required: []string{"pattern"},
136 }
137}
138
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) [ ] { } ^ $ |) backslash-escaped, so the
// result matches the input text literally.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
150
151func (g *grepTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
152 var params GrepParams
153 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
154 return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
155 }
156
157 if params.Pattern == "" {
158 return NewTextErrorResponse("pattern is required"), nil
159 }
160
161 // If literal_text is true, escape the pattern
162 searchPattern := params.Pattern
163 if params.LiteralText {
164 searchPattern = escapeRegexPattern(params.Pattern)
165 }
166
167 searchPath := params.Path
168 if searchPath == "" {
169 searchPath = g.workingDir
170 }
171
172 matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
173 if err != nil {
174 return ToolResponse{}, fmt.Errorf("error searching files: %w", err)
175 }
176
177 var output strings.Builder
178 if len(matches) == 0 {
179 output.WriteString("No files found")
180 } else {
181 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
182
183 currentFile := ""
184 for _, match := range matches {
185 if currentFile != match.path {
186 if currentFile != "" {
187 output.WriteString("\n")
188 }
189 currentFile = match.path
190 fmt.Fprintf(&output, "%s:\n", match.path)
191 }
192 if match.lineNum > 0 {
193 if match.charNum > 0 {
194 fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, match.lineText)
195 } else {
196 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText)
197 }
198 } else {
199 fmt.Fprintf(&output, " %s\n", match.path)
200 }
201 }
202
203 if truncated {
204 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
205 }
206 }
207
208 return WithResponseMetadata(
209 NewTextResponse(output.String()),
210 GrepResponseMetadata{
211 NumberOfMatches: len(matches),
212 Truncated: truncated,
213 },
214 ), nil
215}
216
217func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
218 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
219 if err != nil {
220 matches, err = searchFilesWithRegex(pattern, rootPath, include)
221 if err != nil {
222 return nil, false, err
223 }
224 }
225
226 sort.Slice(matches, func(i, j int) bool {
227 return matches[i].modTime.After(matches[j].modTime)
228 })
229
230 truncated := len(matches) > limit
231 if truncated {
232 matches = matches[:limit]
233 }
234
235 return matches, truncated, nil
236}
237
238func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
239 cmd := getRgSearchCmd(ctx, pattern, path, include)
240 if cmd == nil {
241 return nil, fmt.Errorf("ripgrep not found in $PATH")
242 }
243
244 // Only add ignore files if they exist
245 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
246 ignorePath := filepath.Join(path, ignoreFile)
247 if _, err := os.Stat(ignorePath); err == nil {
248 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
249 }
250 }
251
252 output, err := cmd.Output()
253 if err != nil {
254 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
255 return []grepMatch{}, nil
256 }
257 return nil, err
258 }
259
260 var matches []grepMatch
261 for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
262 if len(line) == 0 {
263 continue
264 }
265 var match ripgrepMatch
266 if err := json.Unmarshal(line, &match); err != nil {
267 continue
268 }
269 if match.Type != "match" {
270 continue
271 }
272 for _, m := range match.Data.Submatches {
273 fi, err := os.Stat(match.Data.Path.Text)
274 if err != nil {
275 continue // Skip files we can't access
276 }
277 matches = append(matches, grepMatch{
278 path: match.Data.Path.Text,
279 modTime: fi.ModTime(),
280 lineNum: match.Data.LineNumber,
281 charNum: m.Start + 1, // ensure 1-based
282 lineText: strings.TrimSpace(match.Data.Lines.Text),
283 })
284 // only get the first match of each line
285 break
286 }
287 }
288 return matches, nil
289}
290
// ripgrepMatch is the subset of one JSON message from ripgrep's output that
// searchWithRipgrep decodes; only messages whose Type is "match" are used.
//
// NOTE(review): non-UTF-8 paths or lines may be encoded under a different key
// than "text" and would decode here as an empty Text; confirm against
// ripgrep's JSON output format documentation.
type ripgrepMatch struct {
	// Type is the message kind; searchWithRipgrep keeps only "match".
	Type string `json:"type"`
	Data struct {
		// Path.Text is the path of the file that matched.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines.Text is the matching line as reported by ripgrep.
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		// LineNumber is the line number of the match.
		LineNumber int `json:"line_number"`
		// Submatches lists the match positions in the line; Start is a
		// 0-based offset (searchWithRipgrep adds 1).
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
306
307func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
308 matches := []grepMatch{}
309
310 // Use cached regex compilation
311 regex, err := searchRegexCache.get(pattern)
312 if err != nil {
313 return nil, fmt.Errorf("invalid regex pattern: %w", err)
314 }
315
316 var includePattern *regexp.Regexp
317 if include != "" {
318 regexPattern := globToRegex(include)
319 includePattern, err = globRegexCache.get(regexPattern)
320 if err != nil {
321 return nil, fmt.Errorf("invalid include pattern: %w", err)
322 }
323 }
324
325 // Create walker with gitignore and crushignore support
326 walker := fsext.NewFastGlobWalker(rootPath)
327
328 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
329 if err != nil {
330 return nil // Skip errors
331 }
332
333 if info.IsDir() {
334 // Check if directory should be skipped
335 if walker.ShouldSkip(path) {
336 return filepath.SkipDir
337 }
338 return nil // Continue into directory
339 }
340
341 // Use walker's shouldSkip method for files
342 if walker.ShouldSkip(path) {
343 return nil
344 }
345
346 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
347 base := filepath.Base(path)
348 if base != "." && strings.HasPrefix(base, ".") {
349 return nil
350 }
351
352 if includePattern != nil && !includePattern.MatchString(path) {
353 return nil
354 }
355
356 match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
357 if err != nil {
358 return nil // Skip files we can't read
359 }
360
361 if match {
362 matches = append(matches, grepMatch{
363 path: path,
364 modTime: info.ModTime(),
365 lineNum: lineNum,
366 charNum: charNum,
367 lineText: lineText,
368 })
369
370 if len(matches) >= 200 {
371 return filepath.SkipAll
372 }
373 }
374
375 return nil
376 })
377 if err != nil {
378 return nil, err
379 }
380
381 return matches, nil
382}
383
384func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
385 // Only search text files.
386 if !isTextFile(filePath) {
387 return false, 0, 0, "", nil
388 }
389
390 file, err := os.Open(filePath)
391 if err != nil {
392 return false, 0, 0, "", err
393 }
394 defer file.Close()
395
396 scanner := bufio.NewScanner(file)
397 lineNum := 0
398 for scanner.Scan() {
399 lineNum++
400 line := scanner.Text()
401 if loc := pattern.FindStringIndex(line); loc != nil {
402 charNum := loc[0] + 1
403 return true, lineNum, charNum, line, nil
404 }
405 }
406
407 return false, 0, 0, "", scanner.Err()
408}
409
// isTextFile reports whether filePath looks like a text file, judged by
// sniffing up to its first 512 bytes with http.DetectContentType. Files that
// cannot be opened or read are reported as non-text.
func isTextFile(filePath string) bool {
	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	// DetectContentType considers at most 512 bytes. A single Read may
	// legally return fewer bytes than are available, so use io.ReadFull. A
	// shorter file ends with io.EOF or io.ErrUnexpectedEOF, which is fine.
	buffer := make([]byte, 512)
	n, err := io.ReadFull(file, buffer)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return false
	}

	contentType := http.DetectContentType(buffer[:n])
	if strings.HasPrefix(contentType, "text/") {
		return true
	}
	switch contentType {
	case "application/json", "application/xml", "application/javascript", "application/x-sh":
		return true
	}
	return false
}
435
// globToRegex translates a shell-style glob into an unanchored regular
// expression using these rules:
//   - "*" becomes ".*" and "?" becomes ".".
//   - A brace group "{a,b}" becomes the alternation "(a|b)".
//   - "[" and "]" pass through, so character classes keep working.
//   - A backslash escapes the character that follows it.
//   - Every other regex metacharacter is escaped so it matches literally;
//     for example, "*.c++" no longer yields an invalid regex.
func globToRegex(glob string) string {
	var b strings.Builder
	writeGlobRegex(&b, glob, false)
	return b.String()
}

// writeGlobRegex appends the regex translation of glob to b. Inside a brace
// group (inBraces) a ',' separates alternatives and nested '{' is literal.
func writeGlobRegex(b *strings.Builder, glob string, inBraces bool) {
	for i := 0; i < len(glob); i++ {
		c := glob[i]
		switch {
		case c == '*':
			b.WriteString(".*")
		case c == '?':
			b.WriteByte('.')
		case c == '[' || c == ']':
			b.WriteByte(c)
		case c == ',' && inBraces:
			b.WriteByte('|')
		case c == '\\' && i+1 < len(glob):
			i++
			b.WriteString(regexp.QuoteMeta(glob[i : i+1]))
		case c == '{' && !inBraces:
			end := strings.IndexByte(glob[i+1:], '}')
			if end <= 0 {
				// Unterminated or empty "{}": treat the brace literally.
				b.WriteString(`\{`)
				continue
			}
			b.WriteByte('(')
			writeGlobRegex(b, glob[i+1:i+1+end], true)
			b.WriteByte(')')
			i += end + 1
		default:
			// QuoteMeta only touches ASCII metacharacters, so writing UTF-8
			// sequences byte by byte leaves them intact.
			b.WriteString(regexp.QuoteMeta(glob[i : i+1]))
		}
	}
}