fix: remove pool logic

Raphael Amorim created

Change summary

internal/llm/tools/fetch.go       | 64 +++++++-------------------------
internal/llm/tools/grep.go        | 19 +++++----
internal/llm/tools/sourcegraph.go | 63 +++++++------------------------
3 files changed, 40 insertions(+), 106 deletions(-)

Detailed changes

internal/llm/tools/fetch.go 🔗

@@ -7,7 +7,6 @@ import (
 	"io"
 	"net/http"
 	"strings"
-	"sync"
 	"time"
 
 	md "github.com/JohannesKaufmann/html-to-markdown"
@@ -29,10 +28,8 @@ type FetchPermissionsParams struct {
 }
 
 type fetchTool struct {
-	client       *http.Client
-	clientPool   map[int]*http.Client
-	clientPoolMu sync.RWMutex
-	permissions  permission.Service
+	client      *http.Client
+	permissions permission.Service
 }
 
 const (
@@ -78,51 +75,10 @@ func NewFetchTool(permissions permission.Service) BaseTool {
 				IdleConnTimeout:     90 * time.Second,
 			},
 		},
-		clientPool:  make(map[int]*http.Client),
 		permissions: permissions,
 	}
 }
 
-// getClientForTimeout returns a cached client for the given timeout or the default client
-func (t *fetchTool) getClientForTimeout(timeout int) *http.Client {
-	if timeout <= 0 {
-		return t.client
-	}
-
-	maxTimeout := 120 // 2 minutes
-	if timeout > maxTimeout {
-		timeout = maxTimeout
-	}
-
-	// Check if we have a cached client for this timeout
-	t.clientPoolMu.RLock()
-	if client, exists := t.clientPool[timeout]; exists {
-		t.clientPoolMu.RUnlock()
-		return client
-	}
-	t.clientPoolMu.RUnlock()
-
-	// Create and cache a new client
-	t.clientPoolMu.Lock()
-	defer t.clientPoolMu.Unlock()
-
-	// Double-check in case another goroutine created it
-	if client, exists := t.clientPool[timeout]; exists {
-		return client
-	}
-
-	client := &http.Client{
-		Timeout: time.Duration(timeout) * time.Second,
-		Transport: &http.Transport{
-			MaxIdleConns:        100,
-			MaxIdleConnsPerHost: 10,
-			IdleConnTimeout:     90 * time.Second,
-		},
-	}
-	t.clientPool[timeout] = client
-	return client
-}
-
 func (t *fetchTool) Info() ToolInfo {
 	return ToolInfo{
 		Name:        FetchToolName,
@@ -185,16 +141,26 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 		return ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
-	client := t.getClientForTimeout(params.Timeout)
+	// Handle timeout with context
+	requestCtx := ctx
+	if params.Timeout > 0 {
+		maxTimeout := 120 // 2 minutes
+		if params.Timeout > maxTimeout {
+			params.Timeout = maxTimeout
+		}
+		var cancel context.CancelFunc
+		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
+		defer cancel()
+	}
 
-	req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil)
+	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 	}
 
 	req.Header.Set("User-Agent", "crush/1.0")
 
-	resp, err := client.Do(req)
+	resp, err := t.client.Do(req)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
 	}

internal/llm/tools/grep.go 🔗

@@ -402,19 +402,20 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, st
 	return false, 0, "", scanner.Err()
 }
 
+var binaryExts = map[string]struct{}{
+	".exe": {}, ".dll": {}, ".so": {}, ".dylib": {},
+	".bin": {}, ".obj": {}, ".o": {}, ".a": {},
+	".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {},
+	".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {},
+	".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {},
+	".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {},
+}
+
 // isBinaryFile performs a quick check to determine if a file is binary
 func isBinaryFile(filePath string) bool {
 	// Check file extension first (fastest)
 	ext := strings.ToLower(filepath.Ext(filePath))
-	binaryExts := map[string]bool{
-		".exe": true, ".dll": true, ".so": true, ".dylib": true,
-		".bin": true, ".obj": true, ".o": true, ".a": true,
-		".zip": true, ".tar": true, ".gz": true, ".bz2": true,
-		".jpg": true, ".jpeg": true, ".png": true, ".gif": true,
-		".pdf": true, ".doc": true, ".docx": true, ".xls": true,
-		".mp3": true, ".mp4": true, ".avi": true, ".mov": true,
-	}
-	if binaryExts[ext] {
+	if _, isBinary := binaryExts[ext]; isBinary {
 		return true
 	}
 

internal/llm/tools/sourcegraph.go 🔗

@@ -8,7 +8,6 @@ import (
 	"io"
 	"net/http"
 	"strings"
-	"sync"
 	"time"
 )
 
@@ -25,9 +24,7 @@ type SourcegraphResponseMetadata struct {
 }
 
 type sourcegraphTool struct {
-	client       *http.Client
-	clientPool   map[int]*http.Client
-	clientPoolMu sync.RWMutex
+	client *http.Client
 }
 
 const (
@@ -138,50 +135,9 @@ func NewSourcegraphTool() BaseTool {
 				IdleConnTimeout:     90 * time.Second,
 			},
 		},
-		clientPool: make(map[int]*http.Client),
 	}
 }
 
-// getClientForTimeout returns a cached client for the given timeout or the default client
-func (t *sourcegraphTool) getClientForTimeout(timeout int) *http.Client {
-	if timeout <= 0 {
-		return t.client
-	}
-
-	maxTimeout := 120 // 2 minutes
-	if timeout > maxTimeout {
-		timeout = maxTimeout
-	}
-
-	// Check if we have a cached client for this timeout
-	t.clientPoolMu.RLock()
-	if client, exists := t.clientPool[timeout]; exists {
-		t.clientPoolMu.RUnlock()
-		return client
-	}
-	t.clientPoolMu.RUnlock()
-
-	// Create and cache a new client
-	t.clientPoolMu.Lock()
-	defer t.clientPoolMu.Unlock()
-
-	// Double-check in case another goroutine created it
-	if client, exists := t.clientPool[timeout]; exists {
-		return client
-	}
-
-	client := &http.Client{
-		Timeout: time.Duration(timeout) * time.Second,
-		Transport: &http.Transport{
-			MaxIdleConns:        100,
-			MaxIdleConnsPerHost: 10,
-			IdleConnTimeout:     90 * time.Second,
-		},
-	}
-	t.clientPool[timeout] = client
-	return client
-}
-
 func (t *sourcegraphTool) Info() ToolInfo {
 	return ToolInfo{
 		Name:        SourcegraphToolName,
@@ -227,7 +183,18 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	if params.ContextWindow <= 0 {
 		params.ContextWindow = 10 // Default context window
 	}
-	client := t.getClientForTimeout(params.Timeout)
+
+	// Handle timeout with context
+	requestCtx := ctx
+	if params.Timeout > 0 {
+		maxTimeout := 120 // 2 minutes
+		if params.Timeout > maxTimeout {
+			params.Timeout = maxTimeout
+		}
+		var cancel context.CancelFunc
+		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
+		defer cancel()
+	}
 
 	type graphqlRequest struct {
 		Query     string `json:"query"`
@@ -248,7 +215,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	graphqlQuery := string(graphqlQueryBytes)
 
 	req, err := http.NewRequestWithContext(
-		ctx,
+		requestCtx,
 		"POST",
 		"https://sourcegraph.com/.api/graphql",
 		bytes.NewBuffer([]byte(graphqlQuery)),
@@ -260,7 +227,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse,
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("User-Agent", "crush/1.0")
 
-	resp, err := client.Do(req)
+	resp, err := t.client.Do(req)
 	if err != nil {
 		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
 	}