1use futures::stream::StreamExt;
2use sasl::client::mechanisms::{Anonymous, Plain, Scram};
3use sasl::client::Mechanism;
4use sasl::common::scram::{Sha1, Sha256};
5use sasl::common::Credentials;
6use std::collections::HashSet;
7use std::convert::TryFrom;
8use std::str::FromStr;
9use tokio::io::{AsyncRead, AsyncWrite};
10use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
11
12use crate::xmpp_codec::Packet;
13use crate::xmpp_stream::XMPPStream;
14use crate::{AuthError, Error, ProtocolError};
15
16pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
17 mut stream: XMPPStream<S>,
18 creds: Credentials,
19) -> Result<S, Error> {
20 let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism + Send + Sync> + Send>> = vec![
21 Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
22 Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
23 Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
24 Box::new(|| Box::new(Anonymous::new())),
25 ];
26
27 let remote_mechs: HashSet<String> = stream.stream_features.sasl_mechanisms()?.collect();
28
29 for local_mech in local_mechs {
30 let mut mechanism = local_mech();
31 if remote_mechs.contains(mechanism.name()) {
32 let initial = mechanism.initial();
33 let mechanism_name =
34 XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
35
36 stream
37 .send_stanza(Auth {
38 mechanism: mechanism_name,
39 data: initial,
40 })
41 .await?;
42
43 loop {
44 match stream.next().await {
45 Some(Ok(Packet::Stanza(stanza))) => {
46 if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
47 let response = mechanism
48 .response(&challenge.data)
49 .map_err(|e| AuthError::Sasl(e))?;
50
51 // Send response and loop
52 stream.send_stanza(Response { data: response }).await?;
53 } else if let Ok(_) = Success::try_from(stanza.clone()) {
54 return Ok(stream.into_inner());
55 } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
56 return Err(Error::Auth(AuthError::Fail(failure.defined_condition)));
57 // TODO: This code was needed for compatibility with some broken server,
58 // but it’s been forgotten which. It is currently commented out so that we
59 // can find it and fix the server software instead.
60 /*
61 } else if stanza.name() == "failure" {
62 // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
63 return Err(Error::Auth(AuthError::Sasl("failure".to_string())));
64 */
65 } else {
66 // ignore and loop
67 }
68 }
69 Some(Ok(_)) => {
70 // ignore and loop
71 }
72 Some(Err(e)) => return Err(e),
73 None => return Err(Error::Disconnected),
74 }
75 }
76 }
77 }
78
79 Err(AuthError::NoMechanism.into())
80}