diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 780f22a43bae7c9ec1e077c2d5878d3aeb0284ec..7acf23bae61df88792dd805317bdf8a67095dd0d 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -69,6 +69,11 @@ func NewFetchTool(permissions permission.Service) BaseTool { return &fetchTool{ client: &http.Client{ Timeout: 30 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + }, }, permissions: permissions, } @@ -136,25 +141,26 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, permission.ErrorPermissionDenied } - client := t.client + // Handle timeout with context + requestCtx := ctx if params.Timeout > 0 { maxTimeout := 120 // 2 minutes if params.Timeout > maxTimeout { params.Timeout = maxTimeout } - client = &http.Client{ - Timeout: time.Duration(params.Timeout) * time.Second, - } + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) + defer cancel() } - req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil) + req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil) if err != nil { return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", "crush/1.0") - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) } diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index e6992a4f8c23dc440898298d8f0e5d880b2fdc53..3064ee633cf0e54bfb9d14efdd475cda15a38c85 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -377,6 +378,11 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error } func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) { + // Quick binary file detection + if isBinaryFile(filePath) { + return false, 0, "", nil + } + file, err := os.Open(filePath) if err != nil { return false, 0, "", err @@ -396,6 +402,47 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, st return false, 0, "", scanner.Err() } +var binaryExts = map[string]struct{}{ + ".exe": {}, ".dll": {}, ".so": {}, ".dylib": {}, + ".bin": {}, ".obj": {}, ".o": {}, ".a": {}, + ".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {}, + ".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {}, + ".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {}, + ".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {}, +} + +// isBinaryFile performs a quick check to determine if a file is binary +func isBinaryFile(filePath string) bool { + // Check file extension first (fastest) + ext := strings.ToLower(filepath.Ext(filePath)) + if _, isBinary := binaryExts[ext]; isBinary { + return true + } + + // Quick content check for files without clear extensions + file, err := os.Open(filePath) + if err != nil { + return false // If we can't open it, let the caller handle the error + } + defer file.Close() + + // Read first 512 bytes to check for null bytes + buffer := make([]byte, 512) + n, err := file.Read(buffer) + if err != nil && err != io.EOF { + return false + } + + // Check for null bytes (common in binary files) + for i := range n { + if buffer[i] == 0 { + return true + } + } + + return false +} + func globToRegex(glob string) string { regexPattern := strings.ReplaceAll(glob, ".", "\\.") regexPattern = strings.ReplaceAll(regexPattern, "*", ".*") diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index f62e6a961bed962088e0e40670a4276f16174187..29518b7b818da5746d195ea8b7da521d80429962 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -129,6 +129,11 @@ func NewSourcegraphTool() BaseTool { return &sourcegraphTool{ client: &http.Client{ Timeout: 30 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + }, }, } } @@ -178,15 +183,17 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, if params.ContextWindow <= 0 { params.ContextWindow = 10 // Default context window } - client := t.client + + // Handle timeout with context + requestCtx := ctx if params.Timeout > 0 { maxTimeout := 120 // 2 minutes if params.Timeout > maxTimeout { params.Timeout = maxTimeout } - client = &http.Client{ - Timeout: time.Duration(params.Timeout) * time.Second, - } + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) + defer cancel() } type graphqlRequest struct { @@ -208,7 +215,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, graphqlQuery := string(graphqlQueryBytes) req, err := http.NewRequestWithContext( - ctx, + requestCtx, "POST", "https://sourcegraph.com/.api/graphql", bytes.NewBuffer([]byte(graphqlQuery)), @@ -220,7 +227,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", "crush/1.0") - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) }