restructure auth code

Astro created

Change summary

src/client/auth.rs | 154 ++++++++++++++++++++---------------------------
src/client/mod.rs  |   2 
src/error.rs       |   2 
src/xmpp_codec.rs  |   2 
4 files changed, 70 insertions(+), 90 deletions(-)

Detailed changes

src/client/auth.rs 🔗

@@ -1,37 +1,29 @@
-use futures::{sink, Async, Future, Poll, Stream};
+use std::mem::replace;
+use std::str::FromStr;
+use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}};
 use minidom::Element;
 use sasl::client::mechanisms::{Anonymous, Plain, Scram};
 use sasl::client::Mechanism;
 use sasl::common::scram::{Sha1, Sha256};
 use sasl::common::Credentials;
-use std::mem::replace;
-use std::str::FromStr;
 use tokio_io::{AsyncRead, AsyncWrite};
 use try_from::TryFrom;
 use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
 
-use crate::stream_start::StreamStart;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::XMPPStream;
 use crate::{AuthError, Error, ProtocolError};
 
 const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 
-pub struct ClientAuth<S: AsyncWrite> {
-    state: ClientAuthState<S>,
-    mechanism: Box<Mechanism>,
-}
-
-enum ClientAuthState<S: AsyncWrite> {
-    WaitSend(sink::Send<XMPPStream<S>>),
-    WaitRecv(XMPPStream<S>),
-    Start(StreamStart<S>),
-    Invalid,
+pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
+    future: Box<Future<Item = XMPPStream<S>, Error = Error>>,
 }
 
-impl<S: AsyncWrite> ClientAuth<S> {
+impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
     pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
         let mechs: Vec<Box<Mechanism>> = vec![
+            // TODO: Box::new(|| …
             Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
             Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
             Box::new(Plain::from_credentials(creds).unwrap()),
@@ -46,36 +38,74 @@ impl<S: AsyncWrite> ClientAuth<S> {
             .filter(|child| child.is("mechanism", NS_XMPP_SASL))
             .map(|mech_el| mech_el.text())
             .collect();
+        // TODO: iter instead of collect()
         // println!("SASL mechanisms offered: {:?}", mech_names);
 
-        for mut mech in mechs {
-            let name = mech.name().to_owned();
+        for mut mechanism in mechs {
+            let name = mechanism.name().to_owned();
             if mech_names.iter().any(|name1| *name1 == name) {
                 // println!("SASL mechanism selected: {:?}", name);
-                let initial = mech.initial().map_err(AuthError::Sasl)?;
-                let mut this = ClientAuth {
-                    state: ClientAuthState::Invalid,
-                    mechanism: mech,
-                };
-                let mechanism = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
-                this.send(
-                    stream,
-                    Auth {
-                        mechanism,
-                        data: initial,
-                    },
-                );
-                return Ok(this);
+                let initial = mechanism.initial().map_err(AuthError::Sasl)?;
+                let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
+
+                let send_initial = Box::new(stream.send_stanza(Auth {
+                    mechanism: mechanism_name,
+                    data: initial,
+                }))
+                    .map_err(Error::Io);
+                let future = Box::new(send_initial.and_then(
+                    |stream| Self::handle_challenge(stream, mechanism)
+                ).and_then(
+                    |stream| stream.restart()
+                ));
+                return Ok(ClientAuth {
+                    future,
+                });
             }
         }
 
         Err(AuthError::NoMechanism)?
     }
 
-    fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
-        let send = stream.send_stanza(nonza);
-
-        self.state = ClientAuthState::WaitSend(send);
+    fn handle_challenge(stream: XMPPStream<S>, mut mechanism: Box<Mechanism>) -> Box<Future<Item = XMPPStream<S>, Error = Error>> {
+        Box::new(
+            stream.into_future()
+            .map_err(|(e, _stream)| e.into())
+            .and_then(|(stanza, stream)| {
+                match stanza {
+                    Some(Packet::Stanza(stanza)) => {
+                        if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
+                            let response = mechanism
+                                .response(&challenge.data);
+                            Box::new(
+                                response
+                                    .map_err(|e| AuthError::Sasl(e).into())
+                                    .into_future()
+                                    .and_then(|response| {
+                                        // Send response and loop
+                                        stream.send_stanza(Response { data: response })
+                                            .map_err(Error::Io)
+                                            .and_then(|stream| Self::handle_challenge(stream, mechanism))
+                                    })
+                            )
+                        } else if let Ok(_) = Success::try_from(stanza.clone()) {
+                            Box::new(ok(stream))
+                        } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
+                            Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition))))
+                        } else {
+                            // ignore and loop
+                            println!("Ignore: {:?}", stanza);
+                            Self::handle_challenge(stream, mechanism)
+                        }
+                    }
+                    Some(_) => {
+                        // ignore and loop
+                        Self::handle_challenge(stream, mechanism)
+                    }
+                    None => Box::new(err(Error::Disconnected))
+                }
+            })
+        )
     }
 }
 
@@ -84,58 +114,6 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
     type Error = Error;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        let state = replace(&mut self.state, ClientAuthState::Invalid);
-
-        match state {
-            ClientAuthState::WaitSend(mut send) => match send.poll() {
-                Ok(Async::Ready(stream)) => {
-                    self.state = ClientAuthState::WaitRecv(stream);
-                    self.poll()
-                }
-                Ok(Async::NotReady) => {
-                    self.state = ClientAuthState::WaitSend(send);
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(e)?,
-            },
-            ClientAuthState::WaitRecv(mut stream) => match stream.poll() {
-                Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
-                    if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
-                        let response = self
-                            .mechanism
-                            .response(&challenge.data)
-                            .map_err(AuthError::Sasl)?;
-                        self.send(stream, Response { data: response });
-                        self.poll()
-                    } else if let Ok(_) = Success::try_from(stanza.clone()) {
-                        let start = stream.restart();
-                        self.state = ClientAuthState::Start(start);
-                        self.poll()
-                    } else if let Ok(failure) = Failure::try_from(stanza) {
-                        Err(AuthError::Fail(failure.defined_condition))?
-                    } else {
-                        Ok(Async::NotReady)
-                    }
-                }
-                Ok(Async::Ready(_event)) => {
-                    // println!("ClientAuth ignore {:?}", _event);
-                    Ok(Async::NotReady)
-                }
-                Ok(_) => {
-                    self.state = ClientAuthState::WaitRecv(stream);
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(ProtocolError::Parser(e))?,
-            },
-            ClientAuthState::Start(mut start) => match start.poll() {
-                Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)),
-                Ok(Async::NotReady) => {
-                    self.state = ClientAuthState::Start(start);
-                    Ok(Async::NotReady)
-                }
-                Err(e) => Err(e),
-            },
-            ClientAuthState::Invalid => unreachable!(),
-        }
+        self.future.poll()
     }
 }

src/client/mod.rs 🔗

@@ -104,7 +104,7 @@ impl Client {
         StartTlsClient::from_stream(stream)
     }
 
-    fn auth<S: AsyncRead + AsyncWrite>(
+    fn auth<S: AsyncRead + AsyncWrite + 'static>(
         stream: xmpp_stream::XMPPStream<S>,
         username: String,
         password: String,

src/error.rs 🔗

@@ -26,6 +26,8 @@ pub enum Error {
     Auth(AuthError),
     /// TLS error
     Tls(TlsError),
+    /// Connection closed
+    Disconnected,
     /// Shoud never happen
     InvalidState,
 }

src/xmpp_codec.rs 🔗

@@ -19,7 +19,7 @@ use xml5ever::interface::Attribute;
 use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
 
 /// Anything that can be sent or received on an XMPP/XML stream
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Packet {
     /// `<stream:stream>` start tag
     StreamStart(HashMap<String, String>),