improve style: flatten future

Astro created

Change summary

src/client/mod.rs     |   8 
src/component/mod.rs  |   4 
src/happy_eyeballs.rs | 233 +++++++++++++++++++++++++-------------------
3 files changed, 138 insertions(+), 107 deletions(-)

Detailed changes

src/client/mod.rs πŸ”—

@@ -61,7 +61,7 @@ impl Client {
         done(idna::domain_to_ascii(&jid.domain))
             .map_err(|_| Error::Idna)
             .and_then(|domain|
-                      done(Connecter::from_lookup(&domain, "_xmpp-client._tcp", 5222))
+                      done(Connecter::from_lookup(&domain, Some("_xmpp-client._tcp"), 5222))
                       .map_err(Error::Connection)
             )
             .and_then(|connecter|
@@ -75,10 +75,8 @@ impl Client {
                 } else {
                     Err(Error::Protocol(ProtocolError::NoTls))
                 }
-            }).and_then(|starttls|
-                        // TODO: flatten?
-                        starttls
-            ).and_then(|tls_stream|
+            }).flatten()
+            .and_then(|tls_stream|
                        XMPPStream::start(tls_stream, jid2, NS_JABBER_CLIENT.to_owned())
             ).and_then(move |xmpp_stream|
                        done(Self::auth(xmpp_stream, username, password))

src/component/mod.rs πŸ”—

@@ -53,8 +53,8 @@ impl Component {
     fn make_connect(jid: Jid, password: String, server: &str, port: u16) -> impl Future<Item=XMPPStream, Error=Error> {
         let jid1 = jid.clone();
         let password = password;
-        done(Connecter::from_lookup(server, "_xmpp-component._tcp", port))
-            .and_then(|connecter| connecter)
+        done(Connecter::from_lookup(server, None, port))
+            .flatten()
             .map_err(Error::Connection)
             .and_then(move |tcp_stream| {
                 xmpp_stream::XMPPStream::start(tcp_stream, jid1, NS_JABBER_COMPONENT_ACCEPT.to_owned())

src/happy_eyeballs.rs πŸ”—

@@ -1,44 +1,61 @@
 use std::mem;
-use std::net::{SocketAddr, IpAddr};
-use std::collections::{BTreeMap, btree_map};
+use std::net::SocketAddr;
+use std::collections::BTreeMap;
 use std::collections::VecDeque;
+use std::cell::RefCell;
 use futures::{Future, Poll, Async};
 use tokio::net::{ConnectFuture, TcpStream};
 use trust_dns_resolver::{IntoName, Name, ResolverFuture, error::ResolveError};
 use trust_dns_resolver::lookup::SrvLookupFuture;
 use trust_dns_resolver::lookup_ip::LookupIpFuture;
-use trust_dns_proto::rr::rdata::srv::SRV;
 use ConnecterError;
 
+enum State {
+    AwaitResolver(Box<Future<Item = ResolverFuture, Error = ResolveError> + Send>),
+    ResolveSrv(ResolverFuture, SrvLookupFuture),
+    ResolveTarget(ResolverFuture, LookupIpFuture, u16),
+    Connecting(Option<ResolverFuture>, Vec<RefCell<ConnectFuture>>),
+    Invalid,
+}
+
 pub struct Connecter {
     fallback_port: u16,
-    name: Name,
+    srv_domain: Option<Name>,
     domain: Name,
-    resolver_future: Box<Future<Item = ResolverFuture, Error = ResolveError> + Send>,
-    resolver_opt: Option<ResolverFuture>,
-    srv_lookup_opt: Option<SrvLookupFuture>,
-    srvs_opt: Option<btree_map::IntoIter<u16, SRV>>,
-    ip_lookup_opt: Option<(u16, LookupIpFuture)>,
-    ips_opt: Option<(u16, VecDeque<IpAddr>)>,
-    connect_opt: Option<ConnectFuture>,
+    state: State,
+    targets: VecDeque<(Name, u16)>,
 }
 
 impl Connecter {
-    pub fn from_lookup(domain: &str, srv: &str, fallback_port: u16) -> Result<Connecter, ConnecterError> {
+    pub fn from_lookup(domain: &str, srv: Option<&str>, fallback_port: u16) -> Result<Connecter, ConnecterError> {
+        if let Ok(ip) = domain.parse() {
+            // use specified IP address, not domain name, skip the whole dns part
+            let connect =
+                RefCell::new(TcpStream::connect(&SocketAddr::new(ip, fallback_port)));
+            return Ok(Connecter {
+                fallback_port,
+                srv_domain: None,
+                domain: "nohost".into_name()?,
+                state: State::Connecting(None, vec![connect]),
+                targets: VecDeque::new(),
+            });
+        }
+
         let resolver_future = ResolverFuture::from_system_conf()?;
-        let name = format!("{}.{}.", srv, domain).into_name()?;
+        let state = State::AwaitResolver(resolver_future);
+        let srv_domain = match srv {
+            Some(srv) =>
+                Some(format!("{}.{}.", srv, domain).into_name()?),
+            None =>
+                None,
+        };
 
         Ok(Connecter {
             fallback_port,
-            name,
+            srv_domain,
             domain: domain.into_name()?,
-            resolver_future,
-            resolver_opt: None,
-            srv_lookup_opt: None,
-            srvs_opt: None,
-            ip_lookup_opt: None,
-            ips_opt: None,
-            connect_opt: None,
+            state,
+            targets: VecDeque::new(),
         })
     }
 }
@@ -48,102 +65,118 @@ impl Future for Connecter {
     type Error = ConnecterError;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        if self.resolver_opt.is_none() {
-            //println!("Poll resolver future");
-            match self.resolver_future.poll()? {
-                Async::Ready(resolver) =>
-                    self.resolver_opt = Some(resolver),
-                Async::NotReady =>
-                    return Ok(Async::NotReady),
-            }
-        }
-
-        if let Some(ref resolver) = self.resolver_opt {
-            if self.srvs_opt.is_none() {
-                if self.srv_lookup_opt.is_none() {
-                    //println!("Lookup srv: {:?}", self.name);
-                    self.srv_lookup_opt = Some(resolver.lookup_srv(&self.name));
-                }
-
-                if let Some(ref mut srv_lookup) = self.srv_lookup_opt {
-                    match srv_lookup.poll() {
-                        Ok(Async::Ready(t)) => {
-                            let mut srvs = BTreeMap::new();
-                            for srv in t.iter() {
-                                srvs.insert(srv.priority(), srv.clone());
+        let state = mem::replace(&mut self.state, State::Invalid);
+        match state {
+            State::AwaitResolver(mut resolver_future) => {
+                match resolver_future.poll()? {
+                    Async::NotReady => {
+                        self.state = State::AwaitResolver(resolver_future);
+                        Ok(Async::NotReady)
+                    }
+                    Async::Ready(resolver) => {
+                        match &self.srv_domain {
+                            &Some(ref srv_domain) => {
+                                let srv_lookup = resolver.lookup_srv(srv_domain);
+                                self.state = State::ResolveSrv(resolver, srv_lookup);
+                            }
+                            None => {
+                                self.targets =
+                                    [(self.domain.clone(), self.fallback_port)].into_iter()
+                                    .cloned()
+                                    .collect();
+                                self.state = State::Connecting(Some(resolver), vec![]);
                             }
-                            srvs.insert(65535, SRV::new(65535, 0, self.fallback_port, self.domain.clone()));
-                            self.srvs_opt = Some(srvs.into_iter());
                         }
-                        Ok(Async::NotReady) => return Ok(Async::NotReady),
-                        Err(_) => {
-                            //println!("Ignore SVR error: {:?}", e);
-                            let mut srvs = BTreeMap::new();
-                            srvs.insert(65535, SRV::new(65535, 0, self.fallback_port, self.domain.clone()));
-                            self.srvs_opt = Some(srvs.into_iter());
-                        },
+                        self.poll()
                     }
                 }
             }
-
-            if self.connect_opt.is_none() {
-                if self.ips_opt.is_none() {
-                    if self.ip_lookup_opt.is_none() {
-                        if let Some(ref mut srvs) = self.srvs_opt {
-                            if let Some((_, srv)) = srvs.next() {
-                                //println!("Lookup ip: {:?}", srv);
-                                self.ip_lookup_opt = Some((srv.port(), resolver.lookup_ip(srv.target())));
-                            } else {
-                                return Err(ConnecterError::NoSrv);
-                            }
-                        }
+            State::ResolveSrv(resolver, mut srv_lookup) => {
+                match srv_lookup.poll() {
+                    Ok(Async::NotReady) => {
+                        self.state = State::ResolveSrv(resolver, srv_lookup);
+                        Ok(Async::NotReady)
                     }
-
-                    if let Some((port, mut ip_lookup)) = mem::replace(&mut self.ip_lookup_opt, None) {
-                        match ip_lookup.poll() {
-                            Ok(Async::Ready(t)) => {
-                                let mut ip_deque = VecDeque::new();
-                                ip_deque.extend(t.iter());
-                                //println!("IPs: {:?}", ip_deque);
-                                self.ips_opt = Some((port, ip_deque));
-                                self.ip_lookup_opt = None;
-                            },
-                            Ok(Async::NotReady) => {
-                                self.ip_lookup_opt = Some((port, ip_lookup));
-                                return Ok(Async::NotReady)
-                            },
-                            Err(_) => {
-                                //println!("Ignore lookup error: {:?}", e);
-                                self.ip_lookup_opt = None;
-                            }
-                        }
+                    Ok(Async::Ready(srv_result)) => {
+                        let mut srv_map: BTreeMap<_, _> =
+                            srv_result.iter()
+                            .map(|srv| (srv.priority(), (srv.target().clone(), srv.port())))
+                            .collect();
+                        let targets =
+                            srv_map.into_iter()
+                            .map(|(_, tp)| tp)
+                            .collect();
+                        self.targets = targets;
+                        self.state = State::Connecting(Some(resolver), vec![]);
+                        self.poll()
+                    }
+                    Err(_) => {
+                        // ignore, fallback
+                        self.targets =
+                            [(self.domain.clone(), self.fallback_port)].into_iter()
+                            .cloned()
+                            .collect();
+                        self.state = State::Connecting(Some(resolver), vec![]);
+                        self.poll()
                     }
                 }
-
-                if let Some((port, mut ip_deque)) = mem::replace(&mut self.ips_opt, None) {
-                    if let Some(ip) = ip_deque.pop_front() {
-                        //println!("Connect to {:?}:{}", ip, port);
-                        self.connect_opt = Some(TcpStream::connect(&SocketAddr::new(ip, port)));
-                        self.ips_opt = Some((port, ip_deque));
+            }
+            State::Connecting(resolver, mut connects) => {
+                if resolver.is_some() &&
+                    connects.len() == 0 &&
+                    self.targets.len() > 0 {
+                        let resolver = resolver.unwrap();
+                        let (host, port) = self.targets.pop_front().unwrap();
+                        let ip_lookup = resolver.lookup_ip(host);
+                        self.state = State::ResolveTarget(resolver, ip_lookup, port);
+                        self.poll()
+                } else if connects.len() > 0 {
+                    let mut success = None;
+                    connects.retain(|connect| {
+                        match connect.borrow_mut().poll() {
+                            Ok(Async::NotReady) => true,
+                            Ok(Async::Ready(connection)) => {
+                                success = Some(connection);
+                                false
+                            }
+                            Err(_) => false,
+                        }
+                    });
+                    match success {
+                        Some(connection) =>
+                            Ok(Async::Ready(connection)),
+                        None => {
+                            self.state = State::Connecting(resolver, connects);
+                            Ok(Async::NotReady)
+                        },
                     }
+                } else {
+                    Err(ConnecterError::AllFailed)
                 }
             }
-
-            if let Some(mut connect_future) = mem::replace(&mut self.connect_opt, None) {
-                match connect_future.poll() {
-                    Ok(Async::Ready(t)) => return Ok(Async::Ready(t)),
+            State::ResolveTarget(resolver, mut ip_lookup, port) => {
+                match ip_lookup.poll() {
                     Ok(Async::NotReady) => {
-                        self.connect_opt = Some(connect_future);
-                        return Ok(Async::NotReady)
+                        self.state = State::ResolveTarget(resolver, ip_lookup, port);
+                        Ok(Async::NotReady)
+                    }
+                    Ok(Async::Ready(ip_result)) => {
+                        let connects =
+                            ip_result.iter()
+                            .map(|ip| RefCell::new(TcpStream::connect(&SocketAddr::new(ip, port))))
+                            .collect();
+                        self.state = State::Connecting(Some(resolver), connects);
+                        self.poll()
                     }
                     Err(_) => {
-                        //println!("Ignore connect error: {:?}", e);
-                    },
+                        // ignore, next…
+                        self.state = State::Connecting(Some(resolver), vec![]);
+                        self.poll()
+                    }
                 }
             }
+            _ => panic!("")
         }
-
-        Ok(Async::NotReady)
     }
 }