1use std::mem::replace;
2use futures::{Future, Poll, Async, sink, Sink, Stream};
3use tokio_io::{AsyncRead, AsyncWrite};
4use minidom::Element;
5use sasl::common::Credentials;
6use sasl::common::scram::{Sha1, Sha256};
7use sasl::client::Mechanism;
8use sasl::client::mechanisms::{Scram, Plain, Anonymous};
9use serialize::base64::{self, ToBase64, FromBase64};
10
11use xmpp_codec::Packet;
12use xmpp_stream::XMPPStream;
13use stream_start::StreamStart;
14
/// XML namespace of the SASL negotiation elements exchanged by this module
/// (`<mechanisms/>`, `<auth/>`, `<challenge/>`, `<response/>`, `<success/>`,
/// `<failure/>`).
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
16
17pub struct ClientAuth<S: AsyncWrite> {
18 state: ClientAuthState<S>,
19 mechanism: Box<Mechanism>,
20}
21
22enum ClientAuthState<S: AsyncWrite> {
23 WaitSend(sink::Send<XMPPStream<S>>),
24 WaitRecv(XMPPStream<S>),
25 Start(StreamStart<S>),
26 Invalid,
27}
28
29impl<S: AsyncWrite> ClientAuth<S> {
30 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
31 let mechs: Vec<Box<Mechanism>> = vec![
32 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
33 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
34 Box::new(Plain::from_credentials(creds).unwrap()),
35 Box::new(Anonymous::new()),
36 ];
37
38 let mech_names: Vec<String> =
39 match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) {
40 None =>
41 return Err("No auth mechanisms".to_owned()),
42 Some(mechs) =>
43 mechs.children()
44 .filter(|child| child.is("mechanism", NS_XMPP_SASL))
45 .map(|mech_el| mech_el.text())
46 .collect(),
47 };
48 println!("SASL mechanisms offered: {:?}", mech_names);
49
50 for mut mech in mechs {
51 let name = mech.name().to_owned();
52 if mech_names.iter().any(|name1| *name1 == name) {
53 println!("SASL mechanism selected: {:?}", name);
54 let initial = try!(mech.initial());
55 let mut this = ClientAuth {
56 state: ClientAuthState::Invalid,
57 mechanism: mech,
58 };
59 this.send(
60 stream,
61 "auth", &[("mechanism", &name)],
62 &initial
63 );
64 return Ok(this);
65 }
66 }
67
68 Err("No supported SASL mechanism available".to_owned())
69 }
70
71 fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(&str, &str)], content: &[u8]) {
72 let nonza = Element::builder(nonza_name)
73 .ns(NS_XMPP_SASL);
74 let nonza = attrs.iter()
75 .fold(nonza, |nonza, &(name, value)| nonza.attr(name, value))
76 .append(content.to_base64(base64::STANDARD))
77 .build();
78
79 let send = stream.send_stanza(nonza);
80
81 self.state = ClientAuthState::WaitSend(send);
82 }
83}
84
85impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
86 type Item = XMPPStream<S>;
87 type Error = String;
88
89 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
90 let state = replace(&mut self.state, ClientAuthState::Invalid);
91
92 match state {
93 ClientAuthState::WaitSend(mut send) =>
94 match send.poll() {
95 Ok(Async::Ready(stream)) => {
96 self.state = ClientAuthState::WaitRecv(stream);
97 self.poll()
98 },
99 Ok(Async::NotReady) => {
100 self.state = ClientAuthState::WaitSend(send);
101 Ok(Async::NotReady)
102 },
103 Err(e) =>
104 Err(format!("{}", e)),
105 },
106 ClientAuthState::WaitRecv(mut stream) =>
107 match stream.poll() {
108 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
109 if stanza.is("challenge", NS_XMPP_SASL) =>
110 {
111 let content = try!(
112 stanza.text()
113 .from_base64()
114 .map_err(|e| format!("{}", e))
115 );
116 let response = try!(self.mechanism.response(&content));
117 self.send(stream, "response", &[], &response);
118 self.poll()
119 },
120 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
121 if stanza.is("success", NS_XMPP_SASL) =>
122 {
123 let start = stream.restart();
124 self.state = ClientAuthState::Start(start);
125 self.poll()
126 },
127 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
128 if stanza.is("failure", NS_XMPP_SASL) =>
129 {
130 let e = stanza.children().next()
131 .map(|child| child.name())
132 .unwrap_or("Authentication failure");
133 Err(e.to_owned())
134 },
135 Ok(Async::Ready(event)) => {
136 println!("ClientAuth ignore {:?}", event);
137 Ok(Async::NotReady)
138 },
139 Ok(_) => {
140 self.state = ClientAuthState::WaitRecv(stream);
141 Ok(Async::NotReady)
142 },
143 Err(e) =>
144 Err(format!("{}", e)),
145 },
146 ClientAuthState::Start(mut start) =>
147 match start.poll() {
148 Ok(Async::Ready(stream)) =>
149 Ok(Async::Ready(stream)),
150 Ok(Async::NotReady) => {
151 self.state = ClientAuthState::Start(start);
152 Ok(Async::NotReady)
153 },
154 Err(e) =>
155 Err(format!("{}", e)),
156 },
157 ClientAuthState::Invalid =>
158 unreachable!(),
159 }
160 }
161}