diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 2662210d128f5e86e2dcbba0d262722850de4b38..7acf23bae61df88792dd805317bdf8a67095dd0d 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -7,7 +7,6 @@ import ( "io" "net/http" "strings" - "sync" "time" md "github.com/JohannesKaufmann/html-to-markdown" @@ -29,10 +28,8 @@ type FetchPermissionsParams struct { } type fetchTool struct { - client *http.Client - clientPool map[int]*http.Client - clientPoolMu sync.RWMutex - permissions permission.Service + client *http.Client + permissions permission.Service } const ( @@ -78,51 +75,10 @@ func NewFetchTool(permissions permission.Service) BaseTool { IdleConnTimeout: 90 * time.Second, }, }, - clientPool: make(map[int]*http.Client), permissions: permissions, } } -// getClientForTimeout returns a cached client for the given timeout or the default client -func (t *fetchTool) getClientForTimeout(timeout int) *http.Client { - if timeout <= 0 { - return t.client - } - - maxTimeout := 120 // 2 minutes - if timeout > maxTimeout { - timeout = maxTimeout - } - - // Check if we have a cached client for this timeout - t.clientPoolMu.RLock() - if client, exists := t.clientPool[timeout]; exists { - t.clientPoolMu.RUnlock() - return client - } - t.clientPoolMu.RUnlock() - - // Create and cache a new client - t.clientPoolMu.Lock() - defer t.clientPoolMu.Unlock() - - // Double-check in case another goroutine created it - if client, exists := t.clientPool[timeout]; exists { - return client - } - - client := &http.Client{ - Timeout: time.Duration(timeout) * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - }, - } - t.clientPool[timeout] = client - return client -} - func (t *fetchTool) Info() ToolInfo { return ToolInfo{ Name: FetchToolName, @@ -185,16 +141,26 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error return ToolResponse{}, permission.ErrorPermissionDenied } - client := t.getClientForTimeout(params.Timeout) + // Handle timeout with context + requestCtx := ctx + if params.Timeout > 0 { + maxTimeout := 120 // 2 minutes + if params.Timeout > maxTimeout { + params.Timeout = maxTimeout + } + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) + defer cancel() + } - req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil) + req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil) if err != nil { return ToolResponse{}, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("User-Agent", "crush/1.0") - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) } diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index b4eae224d617fdf6ec358d52f5c1011de3df3cab..61d4fb79a614da282fb48cb23b8e0405f28d23ac 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -402,19 +402,20 @@ func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, st return false, 0, "", scanner.Err() } +var binaryExts = map[string]struct{}{ + ".exe": {}, ".dll": {}, ".so": {}, ".dylib": {}, + ".bin": {}, ".obj": {}, ".o": {}, ".a": {}, + ".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {}, + ".jpg": {}, ".jpeg": {}, ".png": {}, ".gif": {}, + ".pdf": {}, ".doc": {}, ".docx": {}, ".xls": {}, + ".mp3": {}, ".mp4": {}, ".avi": {}, ".mov": {}, +} + // isBinaryFile performs a quick check to determine if a file is binary func isBinaryFile(filePath string) bool { // Check file extension first (fastest) ext := strings.ToLower(filepath.Ext(filePath)) - binaryExts := map[string]bool{ - ".exe": true, ".dll": true, ".so": true, ".dylib": true, - ".bin": true, ".obj": true, ".o": true, ".a": true, - ".zip": true, ".tar": true, ".gz": true, ".bz2": true, - ".jpg": true, ".jpeg": true, ".png": true, ".gif": true, - ".pdf": true, ".doc": true, ".docx": true, ".xls": true, - ".mp3": true, ".mp4": true, ".avi": true, ".mov": true, - } - if binaryExts[ext] { + if _, isBinary := binaryExts[ext]; isBinary { return true } diff --git a/internal/llm/tools/sourcegraph.go b/internal/llm/tools/sourcegraph.go index 0fb517512db18e4a817bf99e95971920018052d2..29518b7b818da5746d195ea8b7da521d80429962 100644 --- a/internal/llm/tools/sourcegraph.go +++ b/internal/llm/tools/sourcegraph.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "strings" - "sync" "time" ) @@ -25,9 +24,7 @@ type SourcegraphResponseMetadata struct { } type sourcegraphTool struct { - client *http.Client - clientPool map[int]*http.Client - clientPoolMu sync.RWMutex + client *http.Client } const ( @@ -138,50 +135,9 @@ func NewSourcegraphTool() BaseTool { IdleConnTimeout: 90 * time.Second, }, }, - clientPool: make(map[int]*http.Client), } } -// getClientForTimeout returns a cached client for the given timeout or the default client -func (t *sourcegraphTool) getClientForTimeout(timeout int) *http.Client { - if timeout <= 0 { - return t.client - } - - maxTimeout := 120 // 2 minutes - if timeout > maxTimeout { - timeout = maxTimeout - } - - // Check if we have a cached client for this timeout - t.clientPoolMu.RLock() - if client, exists := t.clientPool[timeout]; exists { - t.clientPoolMu.RUnlock() - return client - } - t.clientPoolMu.RUnlock() - - // Create and cache a new client - t.clientPoolMu.Lock() - defer t.clientPoolMu.Unlock() - - // Double-check in case another goroutine created it - if client, exists := t.clientPool[timeout]; exists { - return client - } - - client := &http.Client{ - Timeout: time.Duration(timeout) * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 100, - MaxIdleConnsPerHost: 10, - IdleConnTimeout: 90 * time.Second, - }, - } - t.clientPool[timeout] = client - return client -} - func (t *sourcegraphTool) Info() ToolInfo { return ToolInfo{ Name: SourcegraphToolName, @@ -227,7 +183,18 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, if params.ContextWindow <= 0 { params.ContextWindow = 10 // Default context window } - client := t.getClientForTimeout(params.Timeout) + + // Handle timeout with context + requestCtx := ctx + if params.Timeout > 0 { + maxTimeout := 120 // 2 minutes + if params.Timeout > maxTimeout { + params.Timeout = maxTimeout + } + var cancel context.CancelFunc + requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second) + defer cancel() + } type graphqlRequest struct { Query string `json:"query"` @@ -248,7 +215,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, graphqlQuery := string(graphqlQueryBytes) req, err := http.NewRequestWithContext( - ctx, + requestCtx, "POST", "https://sourcegraph.com/.api/graphql", bytes.NewBuffer([]byte(graphqlQuery)), @@ -260,7 +227,7 @@ func (t *sourcegraphTool) Run(ctx context.Context, call ToolCall) (ToolResponse, req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", "crush/1.0") - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err) }