@@ -11,6 +11,7 @@ use serialize::base64::{self, ToBase64, FromBase64};
use xmpp_codec::*;
use xmpp_stream::*;
+use stream_start::*;
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
@@ -22,6 +23,7 @@ pub struct ClientAuth<S: AsyncWrite> {
enum ClientAuthState<S: AsyncWrite> {
WaitSend(sink::Send<XMPPStream<S>>),
WaitRecv(XMPPStream<S>),
+ Start(StreamStart<S>),
Invalid,
}
@@ -124,7 +126,11 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "success"
&& stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
- Ok(Async::Ready(stream)),
+ {
+ let start = stream.restart();
+ self.state = ClientAuthState::Start(start);
+ self.poll()
+ },
Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
if stanza.name == "failure"
&& stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
@@ -153,6 +159,17 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
Err(e) =>
Err(format!("{}", e)),
},
+ ClientAuthState::Start(mut start) =>
+ match start.poll() {
+ Ok(Async::Ready(stream)) =>
+ Ok(Async::Ready(stream)),
+ Ok(Async::NotReady) => {
+ self.state = ClientAuthState::Start(start);
+ Ok(Async::NotReady)
+ },
+ Err(e) =>
+ Err(format!("{}", e)),
+ },
ClientAuthState::Invalid =>
unreachable!(),
}
@@ -31,6 +31,13 @@ impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
self.stream.into_inner()
}
+ pub fn restart(self) -> StreamStart<S> {
+ let to = self.stream_attrs.get("from")
+ .map(|s| s.to_owned())
+ .unwrap_or_else(|| "".to_owned());
+ Self::from_stream(self.into_inner(), to.clone())
+ }
+
pub fn can_starttls(&self) -> bool {
self.stream_features
.get_child("starttls", Some(NS_XMPP_TLS))