1use std::mem::replace;
2use futures::*;
3use futures::sink;
4use tokio_io::{AsyncRead, AsyncWrite};
5use xml;
6use sasl::common::Credentials;
7use sasl::common::scram::*;
8use sasl::client::Mechanism;
9use sasl::client::mechanisms::*;
10use serialize::base64::{self, ToBase64, FromBase64};
11
12use xmpp_codec::*;
13use xmpp_stream::*;
14use stream_start::*;
15
/// XML namespace of the SASL negotiation elements (`<mechanisms/>`, `<auth/>`,
/// `<challenge/>`, `<response/>`, `<success/>`, `<failure/>`).
const NS_XMPP_SASL: &'static str = "urn:ietf:params:xml:ns:xmpp-sasl";
17
18pub struct ClientAuth<S: AsyncWrite> {
19 state: ClientAuthState<S>,
20 mechanism: Box<Mechanism>,
21}
22
23enum ClientAuthState<S: AsyncWrite> {
24 WaitSend(sink::Send<XMPPStream<S>>),
25 WaitRecv(XMPPStream<S>),
26 Start(StreamStart<S>),
27 Invalid,
28}
29
30impl<S: AsyncWrite> ClientAuth<S> {
31 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
32 let mechs: Vec<Box<Mechanism>> = vec![
33 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
34 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
35 Box::new(Plain::from_credentials(creds).unwrap()),
36 Box::new(Anonymous::new()),
37 ];
38
39 println!("stream_features: {}", stream.stream_features);
40 let mech_names: Vec<String> =
41 match stream.stream_features.get_child("mechanisms", Some(NS_XMPP_SASL)) {
42 None =>
43 return Err("No auth mechanisms".to_owned()),
44 Some(mechs) =>
45 mechs.get_children("mechanism", Some(NS_XMPP_SASL))
46 .map(|mech_el| mech_el.content_str())
47 .collect(),
48 };
49 println!("Offered mechanisms: {:?}", mech_names);
50
51 for mut mech in mechs {
52 let name = mech.name().to_owned();
53 if mech_names.iter().any(|name1| *name1 == name) {
54 println!("Selected mechanism: {:?}", name);
55 let initial = try!(mech.initial());
56 let mut this = ClientAuth {
57 state: ClientAuthState::Invalid,
58 mechanism: mech,
59 };
60 this.send(
61 stream,
62 "auth", &[("mechanism".to_owned(), name)],
63 &initial
64 );
65 return Ok(this);
66 }
67 }
68
69 Err("No supported SASL mechanism available".to_owned())
70 }
71
72 fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(String, String)], content: &[u8]) {
73 let mut nonza = xml::Element::new(
74 nonza_name.to_owned(),
75 Some(NS_XMPP_SASL.to_owned()),
76 attrs.iter()
77 .map(|&(ref name, ref value)| (name.clone(), None, value.clone()))
78 .collect()
79 );
80 nonza.text(content.to_base64(base64::URL_SAFE));
81
82 println!("send {}", nonza);
83 let send = stream.send(Packet::Stanza(nonza));
84
85 self.state = ClientAuthState::WaitSend(send);
86 }
87}
88
89impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
90 type Item = XMPPStream<S>;
91 type Error = String;
92
93 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
94 let state = replace(&mut self.state, ClientAuthState::Invalid);
95
96 match state {
97 ClientAuthState::WaitSend(mut send) =>
98 match send.poll() {
99 Ok(Async::Ready(stream)) => {
100 println!("send done");
101 self.state = ClientAuthState::WaitRecv(stream);
102 self.poll()
103 },
104 Ok(Async::NotReady) => {
105 self.state = ClientAuthState::WaitSend(send);
106 Ok(Async::NotReady)
107 },
108 Err(e) =>
109 Err(format!("{}", e)),
110 },
111 ClientAuthState::WaitRecv(mut stream) =>
112 match stream.poll() {
113 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
114 if stanza.name == "challenge"
115 && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
116 {
117 let content = try!(
118 stanza.content_str()
119 .from_base64()
120 .map_err(|e| format!("{}", e))
121 );
122 let response = try!(self.mechanism.response(&content));
123 self.send(stream, "response", &[], &response);
124 self.poll()
125 },
126 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
127 if stanza.name == "success"
128 && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
129 {
130 let start = stream.restart();
131 self.state = ClientAuthState::Start(start);
132 self.poll()
133 },
134 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
135 if stanza.name == "failure"
136 && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
137 {
138 let mut e = None;
139 for child in &stanza.children {
140 match child {
141 &xml::Xml::ElementNode(ref child) => {
142 e = Some(child.name.clone());
143 break
144 },
145 _ => (),
146 }
147 }
148 let e = e.unwrap_or_else(|| "Authentication failure".to_owned());
149 Err(e)
150 },
151 Ok(Async::Ready(event)) => {
152 println!("ClientAuth ignore {:?}", event);
153 Ok(Async::NotReady)
154 },
155 Ok(_) => {
156 self.state = ClientAuthState::WaitRecv(stream);
157 Ok(Async::NotReady)
158 },
159 Err(e) =>
160 Err(format!("{}", e)),
161 },
162 ClientAuthState::Start(mut start) =>
163 match start.poll() {
164 Ok(Async::Ready(stream)) =>
165 Ok(Async::Ready(stream)),
166 Ok(Async::NotReady) => {
167 self.state = ClientAuthState::Start(start);
168 Ok(Async::NotReady)
169 },
170 Err(e) =>
171 Err(format!("{}", e)),
172 },
173 ClientAuthState::Invalid =>
174 unreachable!(),
175 }
176 }
177}