@@ -1,37 +1,29 @@
-use futures::{sink, Async, Future, Poll, Stream};
+use std::mem::replace;
+use std::str::FromStr;
+use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}};
use minidom::Element;
use sasl::client::mechanisms::{Anonymous, Plain, Scram};
use sasl::client::Mechanism;
use sasl::common::scram::{Sha1, Sha256};
use sasl::common::Credentials;
-use std::mem::replace;
-use std::str::FromStr;
use tokio_io::{AsyncRead, AsyncWrite};
use try_from::TryFrom;
use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
-use crate::stream_start::StreamStart;
use crate::xmpp_codec::Packet;
use crate::xmpp_stream::XMPPStream;
use crate::{AuthError, Error, ProtocolError};
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
-pub struct ClientAuth<S: AsyncWrite> {
- state: ClientAuthState<S>,
- mechanism: Box<Mechanism>,
-}
-
-enum ClientAuthState<S: AsyncWrite> {
- WaitSend(sink::Send<XMPPStream<S>>),
- WaitRecv(XMPPStream<S>),
- Start(StreamStart<S>),
- Invalid,
+pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
+ future: Box<Future<Item = XMPPStream<S>, Error = Error>>,
}
-impl<S: AsyncWrite> ClientAuth<S> {
+impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
let mechs: Vec<Box<Mechanism>> = vec![
+ // TODO: Box::new(|| …
Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
Box::new(Plain::from_credentials(creds).unwrap()),
@@ -46,36 +38,74 @@ impl<S: AsyncWrite> ClientAuth<S> {
.filter(|child| child.is("mechanism", NS_XMPP_SASL))
.map(|mech_el| mech_el.text())
.collect();
+ // TODO: iter instead of collect()
// println!("SASL mechanisms offered: {:?}", mech_names);
- for mut mech in mechs {
- let name = mech.name().to_owned();
+ for mut mechanism in mechs {
+ let name = mechanism.name().to_owned();
if mech_names.iter().any(|name1| *name1 == name) {
// println!("SASL mechanism selected: {:?}", name);
- let initial = mech.initial().map_err(AuthError::Sasl)?;
- let mut this = ClientAuth {
- state: ClientAuthState::Invalid,
- mechanism: mech,
- };
- let mechanism = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
- this.send(
- stream,
- Auth {
- mechanism,
- data: initial,
- },
- );
- return Ok(this);
+ let initial = mechanism.initial().map_err(AuthError::Sasl)?;
+ let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
+
+ let send_initial = Box::new(stream.send_stanza(Auth {
+ mechanism: mechanism_name,
+ data: initial,
+ }))
+ .map_err(Error::Io);
+ let future = Box::new(send_initial.and_then(
+ |stream| Self::handle_challenge(stream, mechanism)
+ ).and_then(
+ |stream| stream.restart()
+ ));
+ return Ok(ClientAuth {
+ future,
+ });
}
}
Err(AuthError::NoMechanism)?
}
- fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
- let send = stream.send_stanza(nonza);
-
- self.state = ClientAuthState::WaitSend(send);
+ fn handle_challenge(stream: XMPPStream<S>, mut mechanism: Box<Mechanism>) -> Box<Future<Item = XMPPStream<S>, Error = Error>> {
+ Box::new(
+ stream.into_future()
+ .map_err(|(e, _stream)| e.into())
+ .and_then(|(stanza, stream)| {
+ match stanza {
+ Some(Packet::Stanza(stanza)) => {
+ if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
+ let response = mechanism
+ .response(&challenge.data);
+ Box::new(
+ response
+ .map_err(|e| AuthError::Sasl(e).into())
+ .into_future()
+ .and_then(|response| {
+ // Send response and loop
+ stream.send_stanza(Response { data: response })
+ .map_err(Error::Io)
+ .and_then(|stream| Self::handle_challenge(stream, mechanism))
+ })
+ )
+ } else if let Ok(_) = Success::try_from(stanza.clone()) {
+ Box::new(ok(stream))
+ } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
+ Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition))))
+ } else {
+ // ignore and loop
+ println!("Ignore: {:?}", stanza);
+ Self::handle_challenge(stream, mechanism)
+ }
+ }
+ Some(_) => {
+ // ignore and loop
+ Self::handle_challenge(stream, mechanism)
+ }
+ None => Box::new(err(Error::Disconnected))
+ }
+ })
+ )
}
}
@@ -84,58 +114,6 @@ impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- let state = replace(&mut self.state, ClientAuthState::Invalid);
-
- match state {
- ClientAuthState::WaitSend(mut send) => match send.poll() {
- Ok(Async::Ready(stream)) => {
- self.state = ClientAuthState::WaitRecv(stream);
- self.poll()
- }
- Ok(Async::NotReady) => {
- self.state = ClientAuthState::WaitSend(send);
- Ok(Async::NotReady)
- }
- Err(e) => Err(e)?,
- },
- ClientAuthState::WaitRecv(mut stream) => match stream.poll() {
- Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
- if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
- let response = self
- .mechanism
- .response(&challenge.data)
- .map_err(AuthError::Sasl)?;
- self.send(stream, Response { data: response });
- self.poll()
- } else if let Ok(_) = Success::try_from(stanza.clone()) {
- let start = stream.restart();
- self.state = ClientAuthState::Start(start);
- self.poll()
- } else if let Ok(failure) = Failure::try_from(stanza) {
- Err(AuthError::Fail(failure.defined_condition))?
- } else {
- Ok(Async::NotReady)
- }
- }
- Ok(Async::Ready(_event)) => {
- // println!("ClientAuth ignore {:?}", _event);
- Ok(Async::NotReady)
- }
- Ok(_) => {
- self.state = ClientAuthState::WaitRecv(stream);
- Ok(Async::NotReady)
- }
- Err(e) => Err(ProtocolError::Parser(e))?,
- },
- ClientAuthState::Start(mut start) => match start.poll() {
- Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)),
- Ok(Async::NotReady) => {
- self.state = ClientAuthState::Start(start);
- Ok(Async::NotReady)
- }
- Err(e) => Err(e),
- },
- ClientAuthState::Invalid => unreachable!(),
- }
+ self.future.poll()
}
}
@@ -19,7 +19,7 @@ use xml5ever::interface::Attribute;
use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
/// Anything that can be sent or received on an XMPP/XML stream
-#[derive(Debug)]
+#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
/// `<stream:stream>` start tag
StreamStart(HashMap<String, String>),