diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index f5125c8b89f2dda534396f3c51df3839390022ce..ebd1698f2f7bf45ecda15c9160464e3d295ce3d6 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -341,29 +341,32 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon return nil, err } + // XXX: ideally we should be able to use context.WithTimeout here, but, + // the SSE MCP client will start failing once that context is canceled. timeout := mcpTimeout(m) - initCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - if err := c.Start(initCtx); err != nil { + mcpCtx, cancel := context.WithCancel(ctx) + cancelTimer := time.AfterFunc(timeout, cancel) + if err := c.Start(mcpCtx); err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error starting mcp client", "error", err, "name", name) _ = c.Close() + cancel() return nil, err } - if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil { + if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error initializing mcp client", "error", err, "name", name) _ = c.Close() + cancel() return nil, err } - + cancelTimer.Stop() slog.Info("Initialized mcp client", "name", name) return c, nil } func maybeTimeoutErr(err error, timeout time.Duration) error { - if errors.Is(err, context.DeadlineExceeded) { + if errors.Is(err, context.Canceled) { return fmt.Errorf("timed out after %s", timeout) } return err