1package tools
2
3import (
4 "bufio"
5 "bytes"
6 "cmp"
7 "context"
8 _ "embed"
9 "encoding/json"
10 "fmt"
11 "html/template"
12 "io"
13 "net/http"
14 "os"
15 "os/exec"
16 "path/filepath"
17 "regexp"
18 "sort"
19 "strings"
20 "time"
21
22 "charm.land/fantasy"
23 "github.com/charmbracelet/crush/internal/config"
24 "github.com/charmbracelet/crush/internal/csync"
25 "github.com/charmbracelet/crush/internal/fsext"
26)
27
// regexCache provides thread-safe caching of compiled regex patterns, keyed
// by the pattern source text. It embeds *csync.Map, so the map's methods
// (this file uses Get, Set and Reset) are called directly on a regexCache.
type regexCache struct {
	*csync.Map[string, *regexp.Regexp]
}
32
33// newRegexCache creates a new regex cache
34func newRegexCache() *regexCache {
35 return ®exCache{
36 Map: csync.NewMap[string, *regexp.Regexp](),
37 }
38}
39
40// get retrieves a compiled regex from cache or compiles and caches it
41func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
42 re, ok := rc.Get(pattern)
43 if ok && re != nil {
44 return re, nil
45 }
46 re, err := regexp.Compile(pattern)
47 if err != nil {
48 return nil, err
49 }
50 rc.Set(pattern, re)
51 return re, nil
52}
53
54// ResetCache clears compiled regex caches to prevent unbounded growth across sessions.
55func ResetCache() {
56 searchRegexCache.Reset(map[string]*regexp.Regexp{})
57 globRegexCache.Reset(map[string]*regexp.Regexp{})
58}
59
// Package-level regex state shared by every grep invocation. ResetCache
// clears both caches.
var (
	searchRegexCache = newRegexCache()
	globRegexCache   = newRegexCache()
	// globBraceRegex matches a `{...}` brace group in a glob pattern (at least
	// one character inside, no `}` inside). It is compiled once at package
	// initialization; regexp.MustCompile panics there if the pattern is invalid.
	globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
)
67
// GrepParams is the input accepted by the grep tool.
// NOTE(review): the `description` struct tags presumably feed the JSON schema
// that fantasy.NewAgentTool derives from this type — confirm before reordering
// or renaming fields.
type GrepParams struct {
	// Pattern is required; the tool rejects an empty pattern.
	Pattern string `json:"pattern" description:"The regex pattern to search for in file contents"`
	// Path is the search root; when empty the tool's working directory is used.
	Path string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	// Include is an optional glob restricting which files are searched.
	Include string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	// LiteralText makes the tool escape regex metacharacters in Pattern.
	LiteralText bool `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
74
// grepMatch is a single search hit, produced either from ripgrep output or by
// the pure-Go fallback search.
type grepMatch struct {
	// path is the file path exactly as reported by the search backend.
	path string
	// modTime is the file's modification time; results are sorted newest first.
	modTime time.Time
	// lineNum is the 1-based line number of the hit. Values <= 0 make the
	// output show only the path.
	lineNum int
	// charNum is the 1-based byte offset of the match start within the line.
	// Values <= 0 suppress the "Char" column in the output.
	charNum int
	// lineText is the matching line. Ripgrep results are whitespace-trimmed;
	// fallback results only have the line ending removed.
	lineText string
}
82
// GrepResponseMetadata is attached to non-error grep tool responses.
type GrepResponseMetadata struct {
	// NumberOfMatches is how many matches were returned, i.e. after truncation.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether more matches were found than were returned.
	Truncated bool `json:"truncated"`
}
87
const (
	// GrepToolName is the name the tool is registered under.
	GrepToolName = "grep"
	// maxGrepContentWidth is the maximum length, in bytes, of a matching line
	// echoed in the tool output; longer lines are cut and suffixed with "...".
	maxGrepContentWidth = 500
)
92
// grepDescriptionTmpl is the raw text of grep.md.tpl, embedded at build time.
//
//go:embed grep.md.tpl
var grepDescriptionTmpl []byte

// grepDescriptionTpl is the parsed description template; template.Must panics
// during package initialization if grep.md.tpl does not parse.
// NOTE(review): `template` here is html/template, which contextually escapes
// interpolated data. That is harmless for the int it receives today, but
// text/template is the usual choice for Markdown output.
var grepDescriptionTpl = template.Must(
	template.New("grepDescription").
		Parse(string(grepDescriptionTmpl)),
)
100
// grepDescriptionData is the data value executed against grepDescriptionTpl.
type grepDescriptionData struct {
	// MaxResults is exposed to the template as .MaxResults.
	MaxResults int
}
104
105func grepDescription() string {
106 return renderTemplate(grepDescriptionTpl, grepDescriptionData{
107 MaxResults: 100,
108 })
109}
110
// escapeRegexPattern escapes every regex metacharacter in pattern so that it
// matches as literal text.
//
// It delegates to regexp.QuoteMeta, which backslash-escapes exactly the set
// \.+*?()|[]{}^$ in a single pass.
// NOTE(review): the escaped pattern is also handed to ripgrep via searchFiles;
// these backslash escapes are presumed valid in ripgrep's regex syntax too.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
122
123func NewGrepTool(workingDir string, config config.ToolGrep) fantasy.AgentTool {
124 return fantasy.NewAgentTool(
125 GrepToolName,
126 grepDescription(),
127 func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
128 if params.Pattern == "" {
129 return fantasy.NewTextErrorResponse("pattern is required"), nil
130 }
131
132 searchPattern := params.Pattern
133 if params.LiteralText {
134 searchPattern = escapeRegexPattern(params.Pattern)
135 }
136
137 searchPath := cmp.Or(params.Path, workingDir)
138
139 searchCtx, cancel := context.WithTimeout(ctx, config.GetTimeout())
140 defer cancel()
141
142 matches, truncated, err := searchFiles(searchCtx, searchPattern, searchPath, params.Include, 100)
143 if err != nil {
144 return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
145 }
146
147 var output strings.Builder
148 if len(matches) == 0 {
149 output.WriteString("No files found")
150 } else {
151 fmt.Fprintf(&output, "Found %d matches\n", len(matches))
152
153 currentFile := ""
154 for _, match := range matches {
155 if currentFile != match.path {
156 if currentFile != "" {
157 output.WriteString("\n")
158 }
159 currentFile = match.path
160 fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
161 }
162 if match.lineNum > 0 {
163 lineText := match.lineText
164 if len(lineText) > maxGrepContentWidth {
165 lineText = lineText[:maxGrepContentWidth] + "..."
166 }
167 if match.charNum > 0 {
168 fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
169 } else {
170 fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText)
171 }
172 } else {
173 fmt.Fprintf(&output, " %s\n", match.path)
174 }
175 }
176
177 if truncated {
178 output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
179 }
180 }
181
182 return fantasy.WithResponseMetadata(
183 fantasy.NewTextResponse(output.String()),
184 GrepResponseMetadata{
185 NumberOfMatches: len(matches),
186 Truncated: truncated,
187 },
188 ), nil
189 })
190}
191
192func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
193 matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
194 if err != nil {
195 matches, err = searchFilesWithRegex(pattern, rootPath, include)
196 if err != nil {
197 return nil, false, err
198 }
199 }
200
201 sort.Slice(matches, func(i, j int) bool {
202 return matches[i].modTime.After(matches[j].modTime)
203 })
204
205 truncated := len(matches) > limit
206 if truncated {
207 matches = matches[:limit]
208 }
209
210 return matches, truncated, nil
211}
212
213func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
214 cmd := getRgSearchCmd(ctx, pattern, path, include)
215 if cmd == nil {
216 return nil, fmt.Errorf("ripgrep not found in $PATH")
217 }
218
219 // Only add ignore files if they exist
220 for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
221 ignorePath := filepath.Join(path, ignoreFile)
222 if _, err := os.Stat(ignorePath); err == nil {
223 cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
224 }
225 }
226
227 output, err := cmd.Output()
228 if err != nil {
229 if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
230 return []grepMatch{}, nil
231 }
232 return nil, err
233 }
234
235 var matches []grepMatch
236 for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
237 if len(line) == 0 {
238 continue
239 }
240 var match ripgrepMatch
241 if err := json.Unmarshal(line, &match); err != nil {
242 continue
243 }
244 if match.Type != "match" {
245 continue
246 }
247 for _, m := range match.Data.Submatches {
248 fi, err := os.Stat(match.Data.Path.Text)
249 if err != nil {
250 continue // Skip files we can't access
251 }
252 matches = append(matches, grepMatch{
253 path: match.Data.Path.Text,
254 modTime: fi.ModTime(),
255 lineNum: match.Data.LineNumber,
256 charNum: m.Start + 1, // ensure 1-based
257 lineText: strings.TrimSpace(match.Data.Lines.Text),
258 })
259 // only get the first match of each line
260 break
261 }
262 }
263 return matches, nil
264}
265
// ripgrepMatch is the subset of one JSON output record from ripgrep that
// searchWithRipgrep needs. Only records whose Type is "match" are used; all
// other record types are skipped.
// NOTE(review): presumably produced by `rg --json` via getRgSearchCmd. rg may
// report non-UTF-8 paths or lines under a "bytes" key instead of "text"; those
// would decode here as empty strings.
type ripgrepMatch struct {
	Type string `json:"type"`
	Data struct {
		// Path.Text is the path of the file containing the match.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines.Text is the matching line, including its line terminator.
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		LineNumber int `json:"line_number"`
		// Submatches[i].Start is a 0-based byte offset into Lines.Text.
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
281
282func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
283 matches := []grepMatch{}
284
285 // Use cached regex compilation
286 regex, err := searchRegexCache.get(pattern)
287 if err != nil {
288 return nil, fmt.Errorf("invalid regex pattern: %w", err)
289 }
290
291 var includePattern *regexp.Regexp
292 if include != "" {
293 regexPattern := globToRegex(include)
294 includePattern, err = globRegexCache.get(regexPattern)
295 if err != nil {
296 return nil, fmt.Errorf("invalid include pattern: %w", err)
297 }
298 }
299
300 // Create walker with gitignore and crushignore support
301 walker := fsext.NewFastGlobWalker(rootPath)
302
303 err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
304 if err != nil {
305 return nil // Skip errors
306 }
307
308 if info.IsDir() {
309 // Check if directory should be skipped
310 if walker.ShouldSkip(path) {
311 return filepath.SkipDir
312 }
313 return nil // Continue into directory
314 }
315
316 // Use walker's shouldSkip method for files
317 if walker.ShouldSkip(path) {
318 return nil
319 }
320
321 // Skip hidden files (starting with a dot) to match ripgrep's default behavior
322 base := filepath.Base(path)
323 if base != "." && strings.HasPrefix(base, ".") {
324 return nil
325 }
326
327 if includePattern != nil && !includePattern.MatchString(path) {
328 return nil
329 }
330
331 match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
332 if err != nil {
333 return nil // Skip files we can't read
334 }
335
336 if match {
337 matches = append(matches, grepMatch{
338 path: path,
339 modTime: info.ModTime(),
340 lineNum: lineNum,
341 charNum: charNum,
342 lineText: lineText,
343 })
344
345 if len(matches) >= 200 {
346 return filepath.SkipAll
347 }
348 }
349
350 return nil
351 })
352 if err != nil {
353 return nil, err
354 }
355
356 return matches, nil
357}
358
// fileContainsPattern finds the first line of filePath that matches pattern.
//
// It returns (found, lineNum, charNum, lineText, err):
//   - lineNum is the 1-based line number of the match;
//   - charNum is the 1-based byte offset of the match start within that line;
//   - lineText is the line with its trailing "\n" or "\r\n" removed.
//
// A nil pattern, or a file that isTextFile does not classify as text, yields
// found == false with a nil error. Open and read failures are returned as err.
// Lines are read with bufio.Reader.ReadString, so there is no line-length
// limit.
func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
	if pattern == nil {
		return false, 0, 0, "", nil
	}
	// Only search text files.
	if !isTextFile(filePath) {
		return false, 0, 0, "", nil
	}

	file, err := os.Open(filePath)
	if err != nil {
		return false, 0, 0, "", err
	}
	defer file.Close()

	reader := bufio.NewReader(file)
	for lineNum := 1; ; lineNum++ {
		line, err := reader.ReadString('\n')
		// An empty read at EOF means the file ended right after the previous
		// newline (or is empty). It is not a real line, so patterns that
		// match the empty string (e.g. `^$`) must not report it.
		if err == io.EOF && line == "" {
			break
		}
		line = strings.TrimSuffix(line, "\n")
		line = strings.TrimSuffix(line, "\r")
		if loc := pattern.FindStringIndex(line); loc != nil {
			return true, lineNum, loc[0] + 1, line, nil
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			return false, 0, 0, "", err
		}
	}

	return false, 0, 0, "", nil
}

// isTextFile reports whether filePath looks like a text file, based on
// http.DetectContentType applied to (at most) its first 512 bytes. Files that
// cannot be opened or read are reported as non-text. An empty file is
// classified as text/plain by DetectContentType and is therefore accepted.
func isTextFile(filePath string) bool {
	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	// DetectContentType considers at most 512 bytes. io.ReadFull keeps reading
	// until the buffer is full or the file ends, whereas a single Read may
	// legitimately return fewer bytes.
	buffer := make([]byte, 512)
	n, err := io.ReadFull(file, buffer)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return false
	}

	contentType := http.DetectContentType(buffer[:n])

	// NOTE(review): the exact application/* comparisons below look like a
	// safety net; DetectContentType is not known to emit those values.
	return strings.HasPrefix(contentType, "text/") ||
		contentType == "application/json" ||
		contentType == "application/xml" ||
		contentType == "application/javascript" ||
		contentType == "application/x-sh"
}
421
// globToRegex converts a shell-style glob into an unanchored regular
// expression string; callers add anchors if they need them.
//
// Supported glob syntax:
//   - `*` becomes `.*` (it also crosses path separators);
//   - `?` becomes `.`;
//   - `{a,b}` becomes `(a|b)` when a closing `}` with at least one character
//     before it follows; a `{` inside an open group is matched literally;
//   - `[...]` is copied as a character class, with a leading `!` turned into
//     the regex negation `^`.
//
// Every other regex metacharacter (`\ . + ( ) | ^ $` and unmatched brackets
// or braces) is backslash-escaped. For example "*.c++" yields a valid pattern
// that matches the `+` characters literally.
func globToRegex(glob string) string {
	var sb strings.Builder
	sb.Grow(2 * len(glob))

	inBrace := false
	for i := 0; i < len(glob); i++ {
		c := glob[i]
		switch {
		case c == '*':
			sb.WriteString(".*")
		case c == '?':
			sb.WriteByte('.')
		case c == '{' && !inBrace && strings.IndexByte(glob[i+1:], '}') > 0:
			inBrace = true
			sb.WriteByte('(')
		case c == '}' && inBrace:
			inBrace = false
			sb.WriteByte(')')
		case c == ',' && inBrace:
			sb.WriteByte('|')
		case c == '[' && strings.IndexByte(glob[i+1:], ']') > 0:
			end := i + 1 + strings.IndexByte(glob[i+1:], ']')
			class := glob[i+1 : end]
			if len(class) > 1 && class[0] == '!' {
				class = "^" + class[1:]
			}
			sb.WriteByte('[')
			sb.WriteString(class)
			sb.WriteByte(']')
			i = end
		default:
			if strings.IndexByte(`\.+()|^$[]{}`, c) >= 0 {
				sb.WriteByte('\\')
			}
			sb.WriteByte(c)
		}
	}

	return sb.String()
}