1package tools
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 _ "embed"
8 "encoding/json"
9 "fmt"
10 "io"
11 "net/http"
12 "os"
13 "os/exec"
14 "path/filepath"
15 "regexp"
16 "sort"
17 "strings"
18 "time"
19
20 "charm.land/fantasy"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/fsext"
23)
24
// regexCache caches compiled regular expressions keyed by their pattern
// source string. It embeds *csync.Map, so that map's methods (GetOrSet and
// Reset are the ones used in this file) are promoted onto the cache.
// NOTE(review): concurrency safety is presumably provided by csync.Map —
// confirm against that package.
type regexCache struct {
	*csync.Map[string, *regexp.Regexp]
}
29
30// newRegexCache creates a new regex cache
31func newRegexCache() *regexCache {
32 return ®exCache{
33 Map: csync.NewMap[string, *regexp.Regexp](),
34 }
35}
36
37// get retrieves a compiled regex from cache or compiles and caches it
38func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
39 var rerr error
40 return rc.GetOrSet(pattern, func() *regexp.Regexp {
41 regex, err := regexp.Compile(pattern)
42 if err != nil {
43 rerr = err
44 }
45 return regex
46 }), rerr
47}
48
49// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
50func ResetCache() {
51 searchRegexCache.Reset(map[string]*regexp.Regexp{})
52 globRegexCache.Reset(map[string]*regexp.Regexp{})
53}
54
// Package-level regex state shared by every grep invocation; the two caches
// are emptied by ResetCache.
var (
	// searchRegexCache holds compiled content-search patterns
	// (see searchFilesWithRegex).
	searchRegexCache = newRegexCache()
	// globRegexCache holds compiled include filters, keyed by the regex
	// string produced by globToRegex.
	globRegexCache = newRegexCache()
	// globBraceRegex matches a single non-nested, non-empty brace group such
	// as {ts,tsx} and captures the text between the braces. It is compiled
	// once at package initialisation.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
62
// GrepParams is the input accepted by the grep tool. The description tags
// carry the per-field documentation attached to each parameter.
type GrepParams struct {
	// Pattern is required; NewGrepTool rejects an empty value.
	Pattern string `json:"pattern" description:"The regex pattern to search for in file contents"`
	// Path is the search root; NewGrepTool substitutes its workingDir when empty.
	Path string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	// Include is an optional glob restricting which files are searched.
	Include string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	// LiteralText makes NewGrepTool pass Pattern through escapeRegexPattern
	// before searching.
	LiteralText bool `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
69
// grepMatch is one reported hit: a single matching line in a single file.
type grepMatch struct {
	// path is the file the match was found in.
	path string
	// modTime is the file's modification time; searchFiles sorts on it,
	// newest first.
	modTime time.Time
	// lineNum is the 1-based line number. NewGrepTool prints only the path
	// when it is not positive.
	lineNum int
	// charNum is the 1-based offset of the match start within the line; it is
	// printed only when positive.
	charNum int
	// lineText is the text of the matching line.
	lineText string
}
77
// GrepResponseMetadata is attached to every successful grep tool response.
type GrepResponseMetadata struct {
	// NumberOfMatches is the number of matches included in the response,
	// i.e. counted after any truncation.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether more matches were found than were returned.
	Truncated bool `json:"truncated"`
}
82
const (
	// GrepToolName is the name the grep tool is registered under.
	GrepToolName = "grep"
	// maxGrepContentWidth is the maximum length, in bytes, of a matching line
	// shown in the tool output; longer lines are cut and suffixed with "...".
	maxGrepContentWidth = 500
)
87
// grepDescription holds the contents of grep.md, embedded at build time. It
// is passed to fantasy.NewAgentTool as the tool description.
//
//go:embed grep.md
var grepDescription []byte
90
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) | [ ] { } ^ $) backslash-escaped, so the
// result matches the original text literally.
func escapeRegexPattern(pattern string) string {
	// regexp.QuoteMeta escapes exactly this character set in a single pass.
	return regexp.QuoteMeta(pattern)
}
102
103func NewGrepTool(workingDir string) fantasy.AgentTool {
104 return fantasy.NewAgentTool(
105 GrepToolName,
106 string(grepDescription),
107 func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
108 if params.Pattern == "" {
109 return fantasy.NewTextErrorResponse("pattern is required"), nil
110 }
111
112 // If literal_text is true, escape the pattern
113 searchPattern := params.Pattern
114 if params.LiteralText {
115 searchPattern = escapeRegexPattern(params.Pattern)
116 }
117
118 searchPath := params.Path
119 if searchPath == "" {
120 searchPath = workingDir
121 }
122
123 matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
124 if err != nil {
125 return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
126 }
127
128 var output strings.Builder
129 if len(matches) == 0 {
130 output.WriteString("No files found")
131 } else {
132 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
133
134 currentFile := ""
135 for _, match := range matches {
136 if currentFile != match.path {
137 if currentFile != "" {
138 output.WriteString("\n")
139 }
140 currentFile = match.path
141 fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
142 }
143 if match.lineNum > 0 {
144 lineText := match.lineText
145 if len(lineText) > maxGrepContentWidth {
146 lineText = lineText[:maxGrepContentWidth] + "..."
147 }
148 if match.charNum > 0 {
149 fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
150 } else {
151 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText)
152 }
153 } else {
154 fmt.Fprintf(&output, " %s\n", match.path)
155 }
156 }
157
158 if truncated {
159 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
160 }
161 }
162
163 return fantasy.WithResponseMetadata(
164 fantasy.NewTextResponse(output.String()),
165 GrepResponseMetadata{
166 NumberOfMatches: len(matches),
167 Truncated: truncated,
168 },
169 ), nil
170 })
171}
172
173func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
174 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
175 if err != nil {
176 matches, err = searchFilesWithRegex(pattern, rootPath, include)
177 if err != nil {
178 return nil, false, err
179 }
180 }
181
182 sort.Slice(matches, func(i, j int) bool {
183 return matches[i].modTime.After(matches[j].modTime)
184 })
185
186 truncated := len(matches) > limit
187 if truncated {
188 matches = matches[:limit]
189 }
190
191 return matches, truncated, nil
192}
193
194func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
195 cmd := getRgSearchCmd(ctx, pattern, path, include)
196 if cmd == nil {
197 return nil, fmt.Errorf("ripgrep not found in $PATH")
198 }
199
200 // Only add ignore files if they exist
201 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
202 ignorePath := filepath.Join(path, ignoreFile)
203 if _, err := os.Stat(ignorePath); err == nil {
204 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
205 }
206 }
207
208 output, err := cmd.Output()
209 if err != nil {
210 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
211 return []grepMatch{}, nil
212 }
213 return nil, err
214 }
215
216 var matches []grepMatch
217 for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
218 if len(line) == 0 {
219 continue
220 }
221 var match ripgrepMatch
222 if err := json.Unmarshal(line, &match); err != nil {
223 continue
224 }
225 if match.Type != "match" {
226 continue
227 }
228 for _, m := range match.Data.Submatches {
229 fi, err := os.Stat(match.Data.Path.Text)
230 if err != nil {
231 continue // Skip files we can't access
232 }
233 matches = append(matches, grepMatch{
234 path: match.Data.Path.Text,
235 modTime: fi.ModTime(),
236 lineNum: match.Data.LineNumber,
237 charNum: m.Start + 1, // ensure 1-based
238 lineText: strings.TrimSpace(match.Data.Lines.Text),
239 })
240 // only get the first match of each line
241 break
242 }
243 }
244 return matches, nil
245}
246
// ripgrepMatch is the shape each line of ripgrep's output is decoded into by
// searchWithRipgrep. Only records whose Type is "match" are used there.
// NOTE(review): this presumably mirrors ripgrep's --json message format;
// getRgSearchCmd is not visible here — confirm.
type ripgrepMatch struct {
	// Type is the record kind; searchWithRipgrep keeps only "match".
	Type string `json:"type"`
	Data struct {
		// Path.Text is the path of the file containing the match.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines.Text is the text of the matching line.
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		// LineNumber is copied unchanged into grepMatch.lineNum.
		LineNumber int `json:"line_number"`
		// Submatches lists the matches on the line; Start of the first one,
		// plus one, becomes grepMatch.charNum.
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
262
263func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
264 matches := []grepMatch{}
265
266 // Use cached regex compilation
267 regex, err := searchRegexCache.get(pattern)
268 if err != nil {
269 return nil, fmt.Errorf("invalid regex pattern: %w", err)
270 }
271
272 var includePattern *regexp.Regexp
273 if include != "" {
274 regexPattern := globToRegex(include)
275 includePattern, err = globRegexCache.get(regexPattern)
276 if err != nil {
277 return nil, fmt.Errorf("invalid include pattern: %w", err)
278 }
279 }
280
281 // Create walker with gitignore and crushignore support
282 walker := fsext.NewFastGlobWalker(rootPath)
283
284 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
285 if err != nil {
286 return nil // Skip errors
287 }
288
289 if info.IsDir() {
290 // Check if directory should be skipped
291 if walker.ShouldSkip(path) {
292 return filepath.SkipDir
293 }
294 return nil // Continue into directory
295 }
296
297 // Use walker's shouldSkip method for files
298 if walker.ShouldSkip(path) {
299 return nil
300 }
301
302 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
303 base := filepath.Base(path)
304 if base != "." && strings.HasPrefix(base, ".") {
305 return nil
306 }
307
308 if includePattern != nil && !includePattern.MatchString(path) {
309 return nil
310 }
311
312 match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
313 if err != nil {
314 return nil // Skip files we can't read
315 }
316
317 if match {
318 matches = append(matches, grepMatch{
319 path: path,
320 modTime: info.ModTime(),
321 lineNum: lineNum,
322 charNum: charNum,
323 lineText: lineText,
324 })
325
326 if len(matches) >= 200 {
327 return filepath.SkipAll
328 }
329 }
330
331 return nil
332 })
333 if err != nil {
334 return nil, err
335 }
336
337 return matches, nil
338}
339
340func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
341 // Only search text files.
342 if !isTextFile(filePath) {
343 return false, 0, 0, "", nil
344 }
345
346 file, err := os.Open(filePath)
347 if err != nil {
348 return false, 0, 0, "", err
349 }
350 defer file.Close()
351
352 scanner := bufio.NewScanner(file)
353 lineNum := 0
354 for scanner.Scan() {
355 lineNum++
356 line := scanner.Text()
357 if loc := pattern.FindStringIndex(line); loc != nil {
358 charNum := loc[0] + 1
359 return true, lineNum, charNum, line, nil
360 }
361 }
362
363 return false, 0, 0, "", scanner.Err()
364}
365
// isTextFile reports whether filePath looks like a text file, judged by
// sniffing the content type of at most its first 512 bytes. A file that
// cannot be opened or read is reported as not text.
func isTextFile(filePath string) bool {
	f, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer f.Close()

	// Only the leading bytes are read; that is all the sniffer is given.
	var head [512]byte
	n, err := f.Read(head[:])
	if err != nil && err != io.EOF {
		return false
	}

	switch mime := http.DetectContentType(head[:n]); mime {
	case "application/json",
		"application/xml",
		"application/javascript",
		"application/x-sh":
		return true
	default:
		// Any text/* type counts as text, whatever parameters follow it.
		return strings.HasPrefix(mime, "text/")
	}
}
391
// globToRegex translates a file glob into regular-expression source that is
// meant to be matched against a file path.
//
// Supported syntax:
//   - *       any run of characters
//   - ?       any single character
//   - {a,b}   alternation (a single, non-nested, non-empty group)
//   - [abc]   character class, copied through; a leading ! negates it
//
// Every other regex metacharacter is escaped, so a glob such as "*.c++"
// produces a valid expression. The result is anchored at the end of the path,
// so "*.js" matches "app.js" but not "app.json"; a glob ending in "/" is left
// unanchored so it keeps matching everything below that directory. The start
// is not anchored, so the glob may match at any depth of the path. An empty
// glob yields an empty string.
func globToRegex(glob string) string {
	// Regex metacharacters that carry no glob meaning at the point where
	// they are checked below.
	const literalMeta = `\.+()|^$[]{}`

	var sb strings.Builder
	sb.Grow(2*len(glob) + 1)

	inBrace, inClass := false, false
	for i := 0; i < len(glob); i++ {
		c := glob[i]
		switch {
		case inClass:
			// Bracket expressions share their syntax with regex classes.
			if c == ']' {
				inClass = false
			}
			sb.WriteByte(c)
		case c == '[' && strings.IndexByte(glob[i+1:], ']') > 0:
			inClass = true
			sb.WriteByte('[')
			if glob[i+1] == '!' {
				// Glob negation is spelled ^ in a regex class.
				sb.WriteByte('^')
				i++
			}
		case c == '*':
			sb.WriteString(".*")
		case c == '?':
			sb.WriteByte('.')
		case c == '{' && !inBrace && strings.IndexByte(glob[i+1:], '}') > 0:
			// Only a closed, non-empty group is treated as alternation;
			// anything else falls through and is matched literally.
			inBrace = true
			sb.WriteByte('(')
		case c == '}' && inBrace:
			inBrace = false
			sb.WriteByte(')')
		case c == ',' && inBrace:
			sb.WriteByte('|')
		case strings.IndexByte(literalMeta, c) >= 0:
			sb.WriteByte('\\')
			sb.WriteByte(c)
		default:
			sb.WriteByte(c)
		}
	}
	if inBrace {
		// The closing brace was consumed by a bracket expression; close the
		// group so the result stays well-formed.
		sb.WriteByte(')')
	}

	if n := len(glob); n > 0 && glob[n-1] != '/' {
		sb.WriteByte('$')
	}
	return sb.String()
}