1package tools
2
3import (
4 "bufio"
5 "bytes"
6 "context"
7 _ "embed"
8 "encoding/json"
9 "fmt"
10 "io"
11 "net/http"
12 "os"
13 "os/exec"
14 "path/filepath"
15 "regexp"
16 "sort"
17 "strings"
18 "time"
19
20 "charm.land/fantasy"
21 "github.com/charmbracelet/crush/internal/config"
22 "github.com/charmbracelet/crush/internal/csync"
23 "github.com/charmbracelet/crush/internal/fsext"
24)
25
26// regexCache provides thread-safe caching of compiled regex patterns
27type regexCache struct {
28 *csync.Map[string, *regexp.Regexp]
29}
30
31// newRegexCache creates a new regex cache
32func newRegexCache() *regexCache {
33 return ®exCache{
34 Map: csync.NewMap[string, *regexp.Regexp](),
35 }
36}
37
38// get retrieves a compiled regex from cache or compiles and caches it
39func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
40 var rerr error
41 return rc.GetOrSet(pattern, func() *regexp.Regexp {
42 regex, err := regexp.Compile(pattern)
43 if err != nil {
44 rerr = err
45 }
46 return regex
47 }), rerr
48}
49
50// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
51func ResetCache() {
52 searchRegexCache.Reset(map[string]*regexp.Regexp{})
53 globRegexCache.Reset(map[string]*regexp.Regexp{})
54}
55
56// Global regex cache instances
57var (
58 searchRegexCache = newRegexCache()
59 globRegexCache = newRegexCache()
60 // Pre-compiled regex for glob conversion (used frequently)
61 globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
62)
63
// GrepParams is the JSON-decoded argument set accepted by the grep tool.
type GrepParams struct {
	// Pattern is the expression to look for; the tool rejects an empty value.
	Pattern string `json:"pattern" description:"The regex pattern to search for in file contents"`
	// Path is the directory to search; empty means the tool's working directory.
	Path string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	// Include optionally restricts the search to files matching a glob.
	Include string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	// LiteralText makes the tool escape regex metacharacters in Pattern.
	LiteralText bool `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
70
// grepMatch describes one hit produced by either search backend.
type grepMatch struct {
	// path is the file the hit was found in.
	path string
	// modTime is the file's modification time; results are ordered by it.
	modTime time.Time
	// lineNum and charNum locate the hit. Both are 1-based when set; a zero
	// value means the position is not known.
	lineNum, charNum int
	// lineText is the text of the matching line.
	lineText string
}
78
// GrepResponseMetadata is the structured metadata attached to a grep tool
// response.
type GrepResponseMetadata struct {
	// NumberOfMatches is the number of matches included in the response.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether more matches existed than were returned.
	Truncated bool `json:"truncated"`
}
83
// GrepToolName is the name under which the grep tool is registered.
const GrepToolName = "grep"

// maxGrepContentWidth caps, in bytes, how much of a matching line is echoed
// back in the tool output before it is cut and suffixed with "...".
const maxGrepContentWidth = 500
88
89//go:embed grep.md
90var grepDescription []byte
91
// escapeRegexPattern returns pattern with every regular-expression
// metacharacter (\ . + * ? ( ) [ ] { } ^ $ |) backslash-escaped, so the result
// matches the original text literally.
func escapeRegexPattern(pattern string) string {
	// regexp.QuoteMeta escapes exactly this character set in a single pass.
	return regexp.QuoteMeta(pattern)
}
103
104func NewGrepTool(workingDir string, config config.ToolGrep) fantasy.AgentTool {
105 return fantasy.NewAgentTool(
106 GrepToolName,
107 string(grepDescription),
108 func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
109 if params.Pattern == "" {
110 return fantasy.NewTextErrorResponse("pattern is required"), nil
111 }
112
113 searchPattern := params.Pattern
114 if params.LiteralText {
115 searchPattern = escapeRegexPattern(params.Pattern)
116 }
117
118 searchPath := params.Path
119 if searchPath == "" {
120 searchPath = workingDir
121 }
122
123 searchCtx, cancel := context.WithTimeout(ctx, config.GetTimeout())
124 defer cancel()
125
126 matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
127 if err != nil {
128 return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
129 }
130
131 var output strings.Builder
132 if len(matches) == 0 {
133 output.WriteString("No files found")
134 } else {
135 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
136
137 currentFile := ""
138 for _, match := range matches {
139 if currentFile != match.path {
140 if currentFile != "" {
141 output.WriteString("\n")
142 }
143 currentFile = match.path
144 fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
145 }
146 if match.lineNum > 0 {
147 lineText := match.lineText
148 if len(lineText) > maxGrepContentWidth {
149 lineText = lineText[:maxGrepContentWidth] + "..."
150 }
151 if match.charNum > 0 {
152 fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
153 } else {
154 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText)
155 }
156 } else {
157 fmt.Fprintf(&output, " %s\n", match.path)
158 }
159 }
160
161 if truncated {
162 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
163 }
164 }
165
166 return fantasy.WithResponseMetadata(
167 fantasy.NewTextResponse(output.String()),
168 GrepResponseMetadata{
169 NumberOfMatches: len(matches),
170 Truncated: truncated,
171 },
172 ), nil
173 })
174}
175
176func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
177 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
178 if err != nil {
179 matches, err = searchFilesWithRegex(pattern, rootPath, include)
180 if err != nil {
181 return nil, false, err
182 }
183 }
184
185 sort.Slice(matches, func(i, j int) bool {
186 return matches[i].modTime.After(matches[j].modTime)
187 })
188
189 truncated := len(matches) > limit
190 if truncated {
191 matches = matches[:limit]
192 }
193
194 return matches, truncated, nil
195}
196
197func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
198 cmd := getRgSearchCmd(ctx, pattern, path, include)
199 if cmd == nil {
200 return nil, fmt.Errorf("ripgrep not found in $PATH")
201 }
202
203 // Only add ignore files if they exist
204 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
205 ignorePath := filepath.Join(path, ignoreFile)
206 if _, err := os.Stat(ignorePath); err == nil {
207 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
208 }
209 }
210
211 output, err := cmd.Output()
212 if err != nil {
213 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
214 return []grepMatch{}, nil
215 }
216 return nil, err
217 }
218
219 var matches []grepMatch
220 for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
221 if len(line) == 0 {
222 continue
223 }
224 var match ripgrepMatch
225 if err := json.Unmarshal(line, &match); err != nil {
226 continue
227 }
228 if match.Type != "match" {
229 continue
230 }
231 for _, m := range match.Data.Submatches {
232 fi, err := os.Stat(match.Data.Path.Text)
233 if err != nil {
234 continue // Skip files we can't access
235 }
236 matches = append(matches, grepMatch{
237 path: match.Data.Path.Text,
238 modTime: fi.ModTime(),
239 lineNum: match.Data.LineNumber,
240 charNum: m.Start + 1, // ensure 1-based
241 lineText: strings.TrimSpace(match.Data.Lines.Text),
242 })
243 // only get the first match of each line
244 break
245 }
246 }
247 return matches, nil
248}
249
// ripgrepMatch is the shape one line of ripgrep output is decoded into. Only
// the fields read by searchWithRipgrep are declared.
type ripgrepMatch struct {
	// Type is the event kind; only "match" events are turned into results.
	Type string `json:"type"`
	// Data is the event payload.
	Data struct {
		// Path names the file the event refers to.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines holds the text of the matching line.
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		// LineNumber is the line number reported for the match.
		LineNumber int `json:"line_number"`
		// Submatches lists the individual hits within the line; Start is the
		// offset of a hit, treated as 0-based by the consumer.
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
265
266func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
267 matches := []grepMatch{}
268
269 // Use cached regex compilation
270 regex, err := searchRegexCache.get(pattern)
271 if err != nil {
272 return nil, fmt.Errorf("invalid regex pattern: %w", err)
273 }
274
275 var includePattern *regexp.Regexp
276 if include != "" {
277 regexPattern := globToRegex(include)
278 includePattern, err = globRegexCache.get(regexPattern)
279 if err != nil {
280 return nil, fmt.Errorf("invalid include pattern: %w", err)
281 }
282 }
283
284 // Create walker with gitignore and crushignore support
285 walker := fsext.NewFastGlobWalker(rootPath)
286
287 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
288 if err != nil {
289 return nil // Skip errors
290 }
291
292 if info.IsDir() {
293 // Check if directory should be skipped
294 if walker.ShouldSkip(path) {
295 return filepath.SkipDir
296 }
297 return nil // Continue into directory
298 }
299
300 // Use walker's shouldSkip method for files
301 if walker.ShouldSkip(path) {
302 return nil
303 }
304
305 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
306 base := filepath.Base(path)
307 if base != "." && strings.HasPrefix(base, ".") {
308 return nil
309 }
310
311 if includePattern != nil && !includePattern.MatchString(path) {
312 return nil
313 }
314
315 match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
316 if err != nil {
317 return nil // Skip files we can't read
318 }
319
320 if match {
321 matches = append(matches, grepMatch{
322 path: path,
323 modTime: info.ModTime(),
324 lineNum: lineNum,
325 charNum: charNum,
326 lineText: lineText,
327 })
328
329 if len(matches) >= 200 {
330 return filepath.SkipAll
331 }
332 }
333
334 return nil
335 })
336 if err != nil {
337 return nil, err
338 }
339
340 return matches, nil
341}
342
343func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
344 // Only search text files.
345 if !isTextFile(filePath) {
346 return false, 0, 0, "", nil
347 }
348
349 file, err := os.Open(filePath)
350 if err != nil {
351 return false, 0, 0, "", err
352 }
353 defer file.Close()
354
355 scanner := bufio.NewScanner(file)
356 lineNum := 0
357 for scanner.Scan() {
358 lineNum++
359 line := scanner.Text()
360 if loc := pattern.FindStringIndex(line); loc != nil {
361 charNum := loc[0] + 1
362 return true, lineNum, charNum, line, nil
363 }
364 }
365
366 return false, 0, 0, "", scanner.Err()
367}
368
369// isTextFile checks if a file is a text file by examining its MIME type.
370func isTextFile(filePath string) bool {
371 file, err := os.Open(filePath)
372 if err != nil {
373 return false
374 }
375 defer file.Close()
376
377 // Read first 512 bytes for MIME type detection.
378 buffer := make([]byte, 512)
379 n, err := file.Read(buffer)
380 if err != nil && err != io.EOF {
381 return false
382 }
383
384 // Detect content type.
385 contentType := http.DetectContentType(buffer[:n])
386
387 // Check if it's a text MIME type.
388 return strings.HasPrefix(contentType, "text/") ||
389 contentType == "application/json" ||
390 contentType == "application/xml" ||
391 contentType == "application/javascript" ||
392 contentType == "application/x-sh"
393}
394
// globToRegex translates a simple file glob into regular-expression source.
//
// Translation rules:
//   - "."        becomes an escaped literal dot
//   - "*"        becomes ".*"
//   - "?"        becomes "."
//   - "{a,b,c}"  becomes "(a|b|c)" when the braces enclose at least one
//     character
//
// Every other character is copied through unchanged. The result is anchored
// at the end with "$" so that "*.js" does not also accept "foo.json" or
// "foo.js.map". It is deliberately not anchored at the start, because the
// pattern is matched against full paths that begin with the search root.
func globToRegex(glob string) string {
	var b strings.Builder
	b.Grow(len(glob) + 8)

	inBraces := false
	for i := 0; i < len(glob); i++ {
		switch c := glob[i]; {
		case c == '.':
			b.WriteString(`\.`)
		case c == '*':
			b.WriteString(".*")
		case c == '?':
			b.WriteByte('.')
		case c == '{' && !inBraces && strings.IndexByte(glob[i+1:], '}') > 0:
			// Opens a group only when a closing brace follows with at least
			// one character in between; "{}" stays literal.
			inBraces = true
			b.WriteByte('(')
		case c == ',' && inBraces:
			b.WriteByte('|')
		case c == '}' && inBraces:
			inBraces = false
			b.WriteByte(')')
		default:
			b.WriteByte(c)
		}
	}

	b.WriteByte('$')
	return b.String()
}