Merge pull request #32 from charmbracelet/perf-http-clients

Raphael Amorim created

perf: optimize HTTP client pooling and binary file detection

Change summary

internal/llm/tools/fetch.go       | 18 ++++++++----
internal/llm/tools/grep.go        | 47 +++++++++++++++++++++++++++++++++
internal/llm/tools/sourcegraph.go | 19 +++++++++----
3 files changed, 72 insertions(+), 12 deletions(-)

Detailed changes

internal/llm/tools/fetch.go 🔗

@@ -69,6 +69,11 @@ func NewFetchTool(permissions permission.Service) BaseTool {
 	return &fetchTool{
 		client: &http.Client{
 			Timeout: 30 * time.Second,
+			Transport: &http.Transport{
+				MaxIdleConns:        100,
+				MaxIdleConnsPerHost: 10,
+				IdleConnTimeout:     90 * time.Second,
+			},
 		},
 		permissions: permissions,
 	}
@@ -136,25 +141,26 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
-	client := t.client
+	// Handle timeout with context
+	requestCtx := ctx
 	if params.Timeout > 0 {
 		maxTimeout := 120 // 2 minutes
 		if params.Timeout > maxTimeout {
 			params.Timeout = maxTimeout
 		}
-		client = &http.Client{
-			Timeout: time.Duration(params.Timeout) * time.Second,
-		}
+		var cancel context.CancelFunc
+		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
+		defer cancel()
 	}
 
-	req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil)
+	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 	}
 
 	req.Header.Set("User-Agent", "crush/1.0")
 
-	resp, err := client.Do(req)
+	resp, err := t.client.Do(req)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
 	}

internal/llm/tools/grep.go 🔗

@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -377,6 +378,11 @@ func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error
 }
 
 func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, string, error) {
+	// Quick binary file detection
+	if isBinaryFile(filePath) {
+		return false, 0, "", nil
+	}
+
 	file, err := os.Open(filePath)
 	if err != nil {
 		return false, 0, "", err
@@ -396,6 +402,47 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, st
 	return false, 0, "", scanner.Err()
 }
 
+var binaryExts = map[string]struct{}{
+	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
+	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
+	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
+	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
+	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
+	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
+}
+
+// isBinaryFile performs a quick check to determine if a file is binary
+func isBinaryFile(filePath string) bool {
+	// Check file extension first (fastest)
+	ext := strings.ToLower(filepath.Ext(filePath))
+	if _, isBinary := binaryExts[ext]; isBinary {
+		return true
+	}
+
+	// Quick content check for files without clear extensions
+	file, err := os.Open(filePath)
+	if err != nil {
+		return false // If we can't open it, let the caller handle the error
+	}
+	defer file.Close()
+
+	// Read first 512 bytes to check for null bytes
+	buffer := make([]byte, 512)
+	n, err := file.Read(buffer)
+	if err != nil && err != io.EOF {
+		return false
+	}
+
+	// Check for null bytes (common in binary files)
+	for i := range n {
+		if buffer[i] == 0 {
+			return true
+		}
+	}
+
+	return false
+}
+
 func globToRegex(glob string) string {
 	regexPattern := strings.ReplaceAll(glob, ".", "\\.")
 	regexPattern = strings.ReplaceAll(regexPattern, "*", ".*")

internal/llm/tools/sourcegraph.go 🔗

@@ -129,6 +129,11 @@ func NewSourcegraphTool() BaseTool {
 	return &sourcegraphTool{
 		client: &http.Client{
 			Timeout: 30 * time.Second,
+			Transport: &http.Transport{
+				MaxIdleConns:        100,
+				MaxIdleConnsPerHost: 10,
+				IdleConnTimeout:     90 * time.Second,
+			},
 		},
 	}
 }
@@ -178,15 +183,17 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	if params.ContextWindow <= 0 {
 		params.ContextWindow = 10 // Default context window
 	}
-	client := t.client
+
+	// Handle timeout with context
+	requestCtx := ctx
 	if params.Timeout > 0 {
 		maxTimeout := 120 // 2 minutes
 		if params.Timeout > maxTimeout {
 			params.Timeout = maxTimeout
 		}
-		client = &http.Client{
-			Timeout: time.Duration(params.Timeout) * time.Second,
-		}
+		var cancel context.CancelFunc
+		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
+		defer cancel()
 	}
 
 	type graphqlRequest struct {
@@ -208,7 +215,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	graphqlQuery := string(graphqlQueryBytes)
 
 	req, err := http.NewRequestWithContext(
-		ctx,
+		requestCtx,
 		"POST",
 		"https://sourcegraph.com/.api/graphql",
 		bytes.NewBuffer([]byte(graphqlQuery)),
@@ -220,7 +227,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("User-Agent", "crush/1.0")
 
-	resp, err := client.Do(req)
+	resp, err := t.client.Do(req)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
 	}