auth.rs

 1use futures::stream::StreamExt;
 2use sasl::client::mechanisms::{Anonymous, Plain, Scram};
 3use sasl::client::Mechanism;
 4use sasl::common::scram::{Sha1, Sha256};
 5use sasl::common::Credentials;
 6use std::collections::HashSet;
 7use std::convert::TryFrom;
 8use std::str::FromStr;
 9use tokio::io::{AsyncRead, AsyncWrite};
10use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
11
12use crate::xmpp_codec::Packet;
13use crate::xmpp_stream::XMPPStream;
14use crate::{AuthError, Error, ProtocolError};
15
16pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
17    mut stream: XMPPStream<S>,
18    creds: Credentials,
19) -> Result<S, Error> {
20    let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
21        Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
22        Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
23        Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
24        Box::new(|| Box::new(Anonymous::new())),
25    ];
26
27    let remote_mechs: HashSet<String> = stream.stream_features.sasl_mechanisms()?.collect();
28
29    for local_mech in local_mechs {
30        let mut mechanism = local_mech();
31        if remote_mechs.contains(mechanism.name()) {
32            let initial = mechanism.initial();
33            let mechanism_name =
34                XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
35
36            stream
37                .send_stanza(Auth {
38                    mechanism: mechanism_name,
39                    data: initial,
40                })
41                .await?;
42
43            loop {
44                match stream.next().await {
45                    Some(Ok(Packet::Stanza(stanza))) => {
46                        if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
47                            let response = mechanism
48                                .response(&challenge.data)
49                                .map_err(|e| AuthError::Sasl(e))?;
50
51                            // Send response and loop
52                            stream.send_stanza(Response { data: response }).await?;
53                        } else if let Ok(_) = Success::try_from(stanza.clone()) {
54                            return Ok(stream.into_inner());
55                        } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
56                            return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
57                        // TODO: This code was needed for compatibility with some broken server,
58                        // but it’s been forgotten which.  It is currently commented out so that we
59                        // can find it and fix the server software instead.
60                        /*
61                        } else if stanza.name() == "failure" {
62                            // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
63                            return Err(Error::Auth(AuthError::Sasl("failure".to_string())));
64                        */
65                        } else {
66                            // ignore and loop
67                        }
68                    }
69                    Some(Ok(_)) => {
70                        // ignore and loop
71                    }
72                    Some(Err(e)) => return Err(e),
73                    None => return Err(Error::Disconnected),
74                }
75            }
76        }
77    }
78
79    Err(AuthError::NoMechanism.into())
80}