diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index cbf50360b9355c05797690678a99d1310b19556f..237d4e18dab0bc518b9d4b6e2c73ef5035d2b348 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "net/http" "os" "os/exec" "path/filepath" @@ -390,8 +391,8 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error } func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) { - // Quick binary file detection - if isBinaryFile(filePath) { + // Only search text files. + if !isTextFile(filePath) { return false, 0, "", nil } @@ -414,45 +415,30 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, st return false, 0, "", scanner.Err() } -var binaryExts = map[string]struct{}{ - ".exe": {}, ".dll": {}, ".so": {}, ".dylib": {}, - ".bin": {}, ".obj": {}, ".o": {}, ".a": {}, - ".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {}, - ".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {}, - ".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {}, - ".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {}, -} - -// isBinaryFile performs a quick check to determine if a file is binary -func isBinaryFile(filePath string) bool { - // Check file extension first (fastest) - ext := strings.ToLower(filepath.Ext(filePath)) - if _, isBinary := binaryExts[ext]; isBinary { - return true - } - - // Quick content check for files without clear extensions +// isTextFile checks if a file is a text file by examining its MIME type. +func isTextFile(filePath string) bool { file, err := os.Open(filePath) if err != nil { - return false // If we can't open it, let the caller handle the error + return false } defer file.Close() - // Read first 512 bytes to check for null bytes + // Read first 512 bytes for MIME type detection. buffer := make([]byte, 512) n, err := file.Read(buffer) if err != nil && err != io.EOF { return false } - // Check for null bytes (common in binary files) - for i := range n { - if buffer[i] == 0 { - return true - } - } + // Detect content type. + contentType := http.DetectContentType(buffer[:n]) - return false + // Check if it's a text MIME type. + return strings.HasPrefix(contentType, "text/") || + contentType == "application/json" || + contentType == "application/xml" || + contentType == "application/javascript" || + contentType == "application/x-sh" } func globToRegex(glob string) string { diff --git a/internal/llm/tools/grep_test.go b/internal/llm/tools/grep_test.go index 53c96b22df444adfba59c6b13995a104411a57be..435b3045b93a8e1297ff2aaeff9ee8977b974b56 100644 --- a/internal/llm/tools/grep_test.go +++ b/internal/llm/tools/grep_test.go @@ -198,3 +198,195 @@ func BenchmarkRegexCacheVsCompile(b *testing.B) { } }) } + +func TestIsTextFile(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + + tests := []struct { + name string + filename string + content []byte + wantText bool + }{ + { + name: "go file", + filename: "test.go", + content: []byte("package main\n\nfunc main() {}\n"), + wantText: true, + }, + { + name: "yaml file", + filename: "config.yaml", + content: []byte("key: value\nlist:\n - item1\n - item2\n"), + wantText: true, + }, + { + name: "yml file", + filename: "config.yml", + content: []byte("key: value\n"), + wantText: true, + }, + { + name: "json file", + filename: "data.json", + content: []byte(`{"key": "value"}`), + wantText: true, + }, + { + name: "javascript file", + filename: "script.js", + content: []byte("console.log('hello');\n"), + wantText: true, + }, + { + name: "typescript file", + filename: "script.ts", + content: []byte("const x: string = 'hello';\n"), + wantText: true, + }, + { + name: "markdown file", + filename: "README.md", + content: []byte("# Title\n\nSome content\n"), + wantText: true, + }, + { + name: "shell script", + filename: "script.sh", + content: []byte("#!/bin/bash\necho 'hello'\n"), + wantText: true, + }, + { + name: "python file", + filename: "script.py", + content: []byte("print('hello')\n"), + wantText: true, + }, + { + name: "xml file", + filename: "data.xml", + content: []byte("\n\n"), + wantText: true, + }, + { + name: "plain text", + filename: "file.txt", + content: []byte("plain text content\n"), + wantText: true, + }, + { + name: "css file", + filename: "style.css", + content: []byte("body { color: red; }\n"), + wantText: true, + }, + { + name: "scss file", + filename: "style.scss", + content: []byte("$primary: blue;\nbody { color: $primary; }\n"), + wantText: true, + }, + { + name: "sass file", + filename: "style.sass", + content: []byte("$primary: blue\nbody\n color: $primary\n"), + wantText: true, + }, + { + name: "rust file", + filename: "main.rs", + content: []byte("fn main() {\n println!(\"Hello, world!\");\n}\n"), + wantText: true, + }, + { + name: "zig file", + filename: "main.zig", + content: []byte("const std = @import(\"std\");\npub fn main() void {}\n"), + wantText: true, + }, + { + name: "java file", + filename: "Main.java", + content: []byte("public class Main {\n public static void main(String[] args) {}\n}\n"), + wantText: true, + }, + { + name: "c file", + filename: "main.c", + content: []byte("#include \nint main() { return 0; }\n"), + wantText: true, + }, + { + name: "cpp file", + filename: "main.cpp", + content: []byte("#include \nint main() { return 0; }\n"), + wantText: true, + }, + { + name: "fish shell", + filename: "script.fish", + content: []byte("#!/usr/bin/env fish\necho 'hello'\n"), + wantText: true, + }, + { + name: "powershell file", + filename: "script.ps1", + content: []byte("Write-Host 'Hello, World!'\n"), + wantText: true, + }, + { + name: "cmd batch file", + filename: "script.bat", + content: []byte("@echo off\necho Hello, World!\n"), + wantText: true, + }, + { + name: "cmd file", + filename: "script.cmd", + content: []byte("@echo off\necho Hello, World!\n"), + wantText: true, + }, + { + name: "binary exe", + filename: "binary.exe", + content: []byte{0x4D, 0x5A, 0x90, 0x00, 0x03, 0x00, 0x00, 0x00}, + wantText: false, + }, + { + name: "png image", + filename: "image.png", + content: []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}, + wantText: false, + }, + { + name: "jpeg image", + filename: "image.jpg", + content: []byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10, 0x4A, 0x46}, + wantText: false, + }, + { + name: "zip archive", + filename: "archive.zip", + content: []byte{0x50, 0x4B, 0x03, 0x04, 0x14, 0x00, 0x00, 0x00}, + wantText: false, + }, + { + name: "pdf file", + filename: "document.pdf", + content: []byte("%PDF-1.4\n%âãÏÓ\n"), + wantText: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + filePath := filepath.Join(tempDir, tt.filename) + require.NoError(t, os.WriteFile(filePath, tt.content, 0o644)) + + got := isTextFile(filePath) + require.Equal(t, tt.wantText, got, "isTextFile(%s) = %v, want %v", tt.filename, got, tt.wantText) + }) + } +}