1package tools
2
3import (
4 "bufio"
5 "bytes"
6 "cmp"
7 "context"
8 _ "embed"
9 "encoding/json"
10 "fmt"
11 "io"
12 "net/http"
13 "os"
14 "os/exec"
15 "path/filepath"
16 "regexp"
17 "sort"
18 "strings"
19 "time"
20
21 "charm.land/fantasy"
22 "github.com/charmbracelet/crush/internal/config"
23 "github.com/charmbracelet/crush/internal/csync"
24 "github.com/charmbracelet/crush/internal/fsext"
25)
26
27// regexCache provides thread-safe caching of compiled regex patterns
28type regexCache struct {
29 *csync.Map[string, *regexp.Regexp]
30}
31
32// newRegexCache creates a new regex cache
33func newRegexCache() *regexCache {
34 return ®exCache{
35 Map: csync.NewMap[string, *regexp.Regexp](),
36 }
37}
38
// get returns the compiled form of pattern, compiling and caching it on first
// use. An invalid pattern yields a nil regexp together with the compile error.
//
// The compile error is inspected only after GetOrSet has returned. Go does not
// specify whether a plain variable in a return list is read before or after a
// call in the same list, so returning the call and the error variable together
// could lose the error. GetOrSet presumably stores whatever the callback
// returns, including nil for an invalid pattern. A cached nil is therefore
// recompiled here, so callers get the error instead of a nil *regexp.Regexp
// paired with a nil error.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	var compileErr error
	re := rc.GetOrSet(pattern, func() *regexp.Regexp {
		compiled, err := regexp.Compile(pattern)
		compileErr = err
		return compiled
	})
	if compileErr != nil {
		return nil, compileErr
	}
	if re == nil {
		// Cached by an earlier failed call; recompute to surface the error.
		return regexp.Compile(pattern)
	}
	return re, nil
}
50
// ResetCache empties both package-level regex caches (compiled search patterns
// and compiled include-glob regexes) so they do not grow without bound across
// sessions.
func ResetCache() {
	for _, cache := range []*regexCache{searchRegexCache, globRegexCache} {
		cache.Reset(map[string]*regexp.Regexp{})
	}
}
56
57// Global regex cache instances
58var (
59 searchRegexCache = newRegexCache()
60 globRegexCache = newRegexCache()
61 // Pre-compiled regex for glob conversion (used frequently)
62 globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
63)
64
// GrepParams holds the arguments of a grep tool call.
//
//   - Pattern: regular expression to search for; the tool rejects an empty
//     pattern.
//   - Path: directory to search; when empty the tool's working directory is
//     used.
//   - Include: optional glob (for example "*.js" or "*.{ts,tsx}") restricting
//     which files are searched.
//   - LiteralText: when true, regex metacharacters in Pattern are escaped so
//     it matches as plain text.
//
// The json tags name the fields in the call payload. The description tags hold
// per-field help text (presumably read by fantasy when building the tool
// schema — confirm).
type GrepParams struct {
	Pattern     string `json:"pattern" description:"The regex pattern to search for in file contents"`
	Path        string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	Include     string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	LiteralText bool   `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
71
// grepMatch is one search hit, produced by either the ripgrep search or the
// pure-Go fallback search.
//
//   - path: the file containing the match, as reported by the search backend.
//   - modTime: the file's modification time; results are sorted newest first.
//   - lineNum: 1-based line number of the match. When it is not positive, the
//     tool output prints only the path.
//   - charNum: 1-based byte offset of the match within the line. When it is
//     not positive, the column is omitted from the output.
//   - lineText: the text of the matching line.
type grepMatch struct {
	path     string
	modTime  time.Time
	lineNum  int
	charNum  int
	lineText string
}
79
// GrepResponseMetadata is attached to every grep tool response.
// NumberOfMatches is the number of matches actually reported, after truncation
// to the result limit. Truncated reports whether more matches were found than
// were returned.
type GrepResponseMetadata struct {
	NumberOfMatches int  `json:"number_of_matches"`
	Truncated       bool `json:"truncated"`
}
84
const (
	// GrepToolName is the name the grep tool is registered under.
	GrepToolName = "grep"
	// maxGrepContentWidth is the maximum number of bytes of a matching line
	// included in the tool output; longer lines are cut and suffixed with "...".
	maxGrepContentWidth = 500
)
89
// grepDescription holds the contents of grep.md, embedded at build time and
// passed to fantasy.NewAgentTool as the tool's description.
//
//go:embed grep.md
var grepDescription []byte
92
// escapeRegexPattern escapes every regex metacharacter in pattern so that it
// matches as literal text.
//
// regexp.QuoteMeta backslash-escapes exactly the characters \.+*?()|[]{}^$ in a
// single pass. The result is also handed to ripgrep. These escapes denote the
// literal character in ripgrep's regex syntax as well (NOTE: confirm if the rg
// regex engine is changed).
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
104
// NewGrepTool returns the grep agent tool.
//
// Each call searches for params.Pattern under params.Path, or under workingDir
// when no path is given. A relative params.Path is resolved against
// workingDir. The pattern is optionally treated as literal text, and files can
// be restricted with the params.Include glob. The search is bounded by
// cfg.GetTimeout() and reports at most 100 matches, grouped under their file
// paths. Line text longer than maxGrepContentWidth bytes is shortened.
// Validation and search failures are returned as tool error responses, not as
// Go errors.
func NewGrepTool(workingDir string, cfg config.ToolGrep) fantasy.AgentTool {
	return fantasy.NewAgentTool(
		GrepToolName,
		string(grepDescription),
		func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
			if params.Pattern == "" {
				return fantasy.NewTextErrorResponse("pattern is required"), nil
			}

			searchPattern := params.Pattern
			if params.LiteralText {
				searchPattern = escapeRegexPattern(params.Pattern)
			}

			// Relative paths are relative to the tool's working directory, not to
			// whatever the process's current directory happens to be.
			searchPath := cmp.Or(params.Path, workingDir)
			if params.Path != "" && !filepath.IsAbs(params.Path) {
				searchPath = filepath.Join(workingDir, params.Path)
			}

			searchCtx, cancel := context.WithTimeout(ctx, cfg.GetTimeout())
			defer cancel()

			matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
			if err != nil {
				return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
			}

			var output strings.Builder
			if len(matches) == 0 {
				output.WriteString("No files found")
			} else {
				fmt.Fprintf(&output, "Found %d matches\n", len(matches))

				currentFile := ""
				for _, match := range matches {
					if currentFile != match.path {
						if currentFile != "" {
							output.WriteString("\n")
						}
						currentFile = match.path
						fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
					}

					if match.lineNum <= 0 {
						fmt.Fprintf(&output, " %s\n", filepath.ToSlash(match.path))
						continue
					}

					lineText := match.lineText
					if len(lineText) > maxGrepContentWidth {
						cut := maxGrepContentWidth
						// Back up over UTF-8 continuation bytes (10xxxxxx) so the
						// cut never splits a multi-byte character. A character is
						// at most 4 bytes, so at most 3 steps are needed.
						for cut > maxGrepContentWidth-3 && lineText[cut]&0xC0 == 0x80 {
							cut--
						}
						lineText = lineText[:cut] + "..."
					}

					if match.charNum > 0 {
						fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
					} else {
						fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText)
					}
				}

				if truncated {
					output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
				}
			}

			return fantasy.WithResponseMetadata(
				fantasy.NewTextResponse(output.String()),
				GrepResponseMetadata{
					NumberOfMatches: len(matches),
					Truncated:       truncated,
				},
			), nil
		})
}
173
// searchFiles finds lines matching pattern in the files under rootPath. It
// prefers ripgrep and falls back to a pure-Go walk when ripgrep is missing or
// fails.
//
// If ctx is already done when ripgrep fails (for example because the search
// timed out), ctx's error is returned instead of starting the fallback. The
// fallback walk does not observe ctx and would otherwise run without a bound.
//
// Results are ordered newest file first by modification time. Ties are broken
// by path and then line number, which makes the order deterministic and keeps
// the lines of one file in order. At most limit matches are returned, and the
// bool reports whether more were found. A negative limit disables truncation.
func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
	matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
	if err != nil {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, false, ctxErr
		}
		matches, err = searchFilesWithRegex(pattern, rootPath, include)
		if err != nil {
			return nil, false, err
		}
	}

	sort.Slice(matches, func(i, j int) bool {
		a, b := matches[i], matches[j]
		if !a.modTime.Equal(b.modTime) {
			return a.modTime.After(b.modTime)
		}
		if a.path != b.path {
			return a.path < b.path
		}
		return a.lineNum < b.lineNum
	})

	truncated := limit >= 0 && len(matches) > limit
	if truncated {
		matches = matches[:limit]
	}

	return matches, truncated, nil
}
194
// searchWithRipgrep runs ripgrep (built by getRgSearchCmd) and converts its
// JSON-lines output into grepMatch values. It keeps only the first submatch of
// each matching line.
//
// A .gitignore or .crushignore file directly inside path is passed to rg as an
// extra --ignore-file argument. Exit status 1 ("no matches") yields an empty,
// non-nil slice. Exit status 2 with some output means rg hit an error (for
// example an unreadable file) while still searching the rest; the matches it
// emitted are kept rather than discarded. Any other failure is returned so the
// caller can fall back to another search.
//
// Each matched file is stat'ed once for its modification time. Matches in
// files that can no longer be stat'ed are dropped.
func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
	cmd := getRgSearchCmd(ctx, pattern, path, include)
	if cmd == nil {
		return nil, fmt.Errorf("ripgrep not found in $PATH")
	}

	// Only add ignore files if they exist.
	for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
		ignorePath := filepath.Join(path, ignoreFile)
		if _, err := os.Stat(ignorePath); err == nil {
			cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
		}
	}

	// Output returns whatever rg wrote to stdout even when it also returns an
	// error, which the exit-status-2 case below relies on.
	output, err := cmd.Output()
	if err != nil {
		exitErr, ok := err.(*exec.ExitError)
		switch {
		case !ok:
			return nil, err
		case exitErr.ExitCode() == 1:
			return []grepMatch{}, nil
		case exitErr.ExitCode() == 2 && len(output) > 0:
			// Partial success: parse what rg managed to report.
		default:
			return nil, err
		}
	}

	type fileState struct {
		modTime time.Time
		ok      bool
	}
	states := make(map[string]fileState)

	var matches []grepMatch
	for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
		if len(line) == 0 {
			continue
		}
		var match ripgrepMatch
		if err := json.Unmarshal(line, &match); err != nil {
			continue
		}
		if match.Type != "match" || len(match.Data.Submatches) == 0 {
			continue
		}

		filePath := match.Data.Path.Text
		state, seen := states[filePath]
		if !seen {
			if fi, err := os.Stat(filePath); err == nil {
				state = fileState{modTime: fi.ModTime(), ok: true}
			}
			states[filePath] = state
		}
		if !state.ok {
			continue // Skip files we can't access.
		}

		// Only the first match of each line is reported.
		matches = append(matches, grepMatch{
			path:     filePath,
			modTime:  state.modTime,
			lineNum:  match.Data.LineNumber,
			charNum:  match.Data.Submatches[0].Start + 1, // ensure 1-based
			lineText: strings.TrimSpace(match.Data.Lines.Text),
		})
	}
	return matches, nil
}
247
// ripgrepMatch decodes one line of ripgrep's JSON output. It keeps only the
// fields searchWithRipgrep reads:
//
//   - the message type; only "match" messages are used
//   - the file path
//   - the matched line's text
//   - its line number
//   - the start offset of each submatch within that line, which
//     searchWithRipgrep turns into a 1-based column by adding 1
//
// NOTE(review): ripgrep encodes non-UTF-8 paths and lines as {"bytes": ...}
// instead of {"text": ...}. Such values decode to "" here — confirm that
// dropping them is intended.
type ripgrepMatch struct {
	Type string `json:"type"`
	Data struct {
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		LineNumber int `json:"line_number"`
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
263
// searchFilesWithRegex is the pure-Go fallback search. It walks rootPath with
// filepath.Walk and records the first line matching pattern in each file that
// is not skipped.
//
// The following are skipped:
//   - directories and files the fsext FastGlobWalker says to skip
//     (gitignore/crushignore aware per its construction)
//   - hidden files, whose names start with "."
//   - files whose slash-separated path does not end with a match for the
//     include glob, when include is non-empty
//   - non-text files and entries that cannot be read
//
// The walk stops after 200 matching files. An invalid pattern or include glob
// is reported as an error.
func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
	matches := []grepMatch{}

	// Use cached regex compilation.
	regex, err := searchRegexCache.get(pattern)
	if err != nil {
		return nil, fmt.Errorf("invalid regex pattern: %w", err)
	}

	var includePattern *regexp.Regexp
	if include != "" {
		// globToRegex yields an unanchored regex. Anchoring it at the end of the
		// path makes "*.js" match "app.js" but not "app.json". The non-capturing
		// group keeps any top-level alternation inside the anchor.
		includePattern, err = globRegexCache.get("(?:" + globToRegex(include) + ")$")
		if err != nil {
			return nil, fmt.Errorf("invalid include pattern: %w", err)
		}
	}

	// Create walker with gitignore and crushignore support.
	walker := fsext.NewFastGlobWalker(rootPath)

	err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil // Skip entries we cannot stat or read.
		}

		if info.IsDir() {
			if walker.ShouldSkip(path) {
				return filepath.SkipDir
			}
			return nil // Continue into directory.
		}

		if walker.ShouldSkip(path) {
			return nil
		}

		// Skip hidden files (starting with a dot) to match ripgrep's default behavior.
		base := filepath.Base(path)
		if base != "." && strings.HasPrefix(base, ".") {
			return nil
		}

		// Match on forward slashes so globs containing "/" also work on Windows.
		if includePattern != nil && !includePattern.MatchString(filepath.ToSlash(path)) {
			return nil
		}

		match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
		if err != nil {
			return nil // Skip files we can't read.
		}
		if !match {
			return nil
		}

		matches = append(matches, grepMatch{
			path:     path,
			modTime:  info.ModTime(),
			lineNum:  lineNum,
			charNum:  charNum,
			lineText: lineText,
		})
		if len(matches) >= 200 {
			return filepath.SkipAll
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	return matches, nil
}
340
341func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
342 // Only search text files.
343 if !isTextFile(filePath) {
344 return false, 0, 0, "", nil
345 }
346
347 file, err := os.Open(filePath)
348 if err != nil {
349 return false, 0, 0, "", err
350 }
351 defer file.Close()
352
353 scanner := bufio.NewScanner(file)
354 lineNum := 0
355 for scanner.Scan() {
356 lineNum++
357 line := scanner.Text()
358 if loc := pattern.FindStringIndex(line); loc != nil {
359 charNum := loc[0] + 1
360 return true, lineNum, charNum, line, nil
361 }
362 }
363
364 return false, 0, 0, "", scanner.Err()
365}
366
// isTextFile reports whether filePath names a regular file whose leading bytes
// sniff as text according to http.DetectContentType.
//
// Non-regular files (directories, named pipes, devices, sockets) are rejected
// before opening. Opening a FIFO for reading blocks until a writer appears,
// which would stall the directory walk. Symlinks are followed by os.Stat, so a
// link to a regular file is still examined.
func isTextFile(filePath string) bool {
	info, err := os.Stat(filePath)
	if err != nil || !info.Mode().IsRegular() {
		return false
	}

	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	// DetectContentType considers at most the first 512 bytes. io.ReadFull
	// avoids acting on a short read. EOF and ErrUnexpectedEOF only mean the file
	// is shorter than the buffer.
	buffer := make([]byte, 512)
	n, err := io.ReadFull(file, buffer)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return false
	}

	contentType := http.DetectContentType(buffer[:n])

	// NOTE(review): DetectContentType's sniffing table does not appear to emit
	// the application/* types below (JSON/JS/shell sniff as text/plain, XML as
	// text/xml). They are kept as harmless extra cases.
	return strings.HasPrefix(contentType, "text/") ||
		contentType == "application/json" ||
		contentType == "application/xml" ||
		contentType == "application/javascript" ||
		contentType == "application/x-sh"
}
392
393func globToRegex(glob string) string {
394 regexPattern := strings.ReplaceAll(glob, ".", "\\.")
395 regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")
396 regexPattern = strings.ReplaceAll(regexPattern, "?", ".")
397
398 // Use pre-compiled regex instead of compiling each time
399 regexPattern = globBraceRegex.ReplaceAllStringFunc(regexPattern, func(match string) string {
400 inner := match[1 : len(match)-1]
401 return "(" + strings.ReplaceAll(inner, ",", "|") + ")"
402 })
403
404 return regexPattern
405}