client_auth: add stream restart

Astro created

Change summary

src/client_auth.rs | 19 ++++++++++++++++++-
src/xmpp_stream.rs |  7 +++++++
2 files changed, 25 insertions(+), 1 deletion(-)

Detailed changes

src/client_auth.rs 🔗

@@ -11,6 +11,7 @@ use serialize::base64::{self, ToBase64, FromBase64};
 
 use xmpp_codec::*;
 use xmpp_stream::*;
+use stream_start::*;
 
 const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 
@@ -22,6 +23,7 @@ pub struct ClientAuth<S: AsyncWrite> {
 enum ClientAuthState<S: AsyncWrite> {
     WaitSend(sink::Send<XMPPStream<S>>),
     WaitRecv(XMPPStream<S>),
+    Start(StreamStart<S>),
     Invalid,
 }
 
@@ -124,7 +126,11 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
                     Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
                         if stanza.name == "success"
                         && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
-                        Ok(Async::Ready(stream)),
+                    {
+                        let start = stream.restart();
+                        self.state = ClientAuthState::Start(start);
+                        self.poll()
+                    },
                     Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
                         if stanza.name == "failure"
                         && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
@@ -153,6 +159,17 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
                     Err(e) =>
                         Err(format!("{}", e)),
                 },
+            ClientAuthState::Start(mut start) =>
+                match start.poll() {
+                    Ok(Async::Ready(stream)) =>
+                        Ok(Async::Ready(stream)),
+                    Ok(Async::NotReady) => {
+                        self.state = ClientAuthState::Start(start);
+                        Ok(Async::NotReady)
+                    },
+                    Err(e) =>
+                        Err(format!("{}", e)),
+                },
             ClientAuthState::Invalid =>
                 unreachable!(),
         }

src/xmpp_stream.rs 🔗

@@ -31,6 +31,13 @@ impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
         self.stream.into_inner()
     }
 
+    pub fn restart(self) -> StreamStart<S> {
+        let to = self.stream_attrs.get("from")
+            .map(|s| s.to_owned())
+            .unwrap_or_else(|| "".to_owned());
+        Self::from_stream(self.into_inner(), to.clone())
+    }
+
     pub fn can_starttls(&self) -> bool {
         self.stream_features
             .get_child("starttls", Some(NS_XMPP_TLS))