1use std::mem::replace;
2use futures::*;
3use futures::sink;
4use tokio_io::{AsyncRead, AsyncWrite};
5use xml;
6use sasl::common::Credentials;
7use sasl::common::scram::*;
8use sasl::client::Mechanism;
9use sasl::client::mechanisms::*;
10use serialize::base64::{self, ToBase64, FromBase64};
11
12use xmpp_codec::*;
13use xmpp_stream::*;
14
/// XML namespace of the SASL negotiation nonzas (`<mechanisms/>`, `<auth/>`,
/// `<challenge/>`, `<response/>`, `<success/>`, `<failure/>`).
const NS_XMPP_SASL: &'static str = "urn:ietf:params:xml:ns:xmpp-sasl";
16
17pub struct ClientAuth<S: AsyncWrite> {
18 state: ClientAuthState<S>,
19 mechanism: Box<Mechanism>,
20}
21
22enum ClientAuthState<S: AsyncWrite> {
23 WaitSend(sink::Send<XMPPStream<S>>),
24 WaitRecv(XMPPStream<S>),
25 Invalid,
26}
27
28impl<S: AsyncWrite> ClientAuth<S> {
29 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
30 let mechs: Vec<Box<Mechanism>> = vec![
31 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
32 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
33 Box::new(Plain::from_credentials(creds).unwrap()),
34 Box::new(Anonymous::new()),
35 ];
36
37 println!("stream_features: {}", stream.stream_features);
38 let mech_names: Vec<String> =
39 match stream.stream_features.get_child("mechanisms", Some(NS_XMPP_SASL)) {
40 None =>
41 return Err("No auth mechanisms".to_owned()),
42 Some(mechs) =>
43 mechs.get_children("mechanism", Some(NS_XMPP_SASL))
44 .map(|mech_el| mech_el.content_str())
45 .collect(),
46 };
47 println!("Offered mechanisms: {:?}", mech_names);
48
49 for mut mech in mechs {
50 let name = mech.name().to_owned();
51 if mech_names.iter().any(|name1| *name1 == name) {
52 println!("Selected mechanism: {:?}", name);
53 let initial = try!(mech.initial());
54 let mut this = ClientAuth {
55 state: ClientAuthState::Invalid,
56 mechanism: mech,
57 };
58 this.send(
59 stream,
60 "auth", &[("mechanism".to_owned(), name)],
61 &initial
62 );
63 return Ok(this);
64 }
65 }
66
67 Err("No supported SASL mechanism available".to_owned())
68 }
69
70 fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(String, String)], content: &[u8]) {
71 let mut nonza = xml::Element::new(
72 nonza_name.to_owned(),
73 Some(NS_XMPP_SASL.to_owned()),
74 attrs.iter()
75 .map(|&(ref name, ref value)| (name.clone(), None, value.clone()))
76 .collect()
77 );
78 nonza.text(content.to_base64(base64::URL_SAFE));
79
80 println!("send {}", nonza);
81 let send = stream.send(Packet::Stanza(nonza));
82
83 self.state = ClientAuthState::WaitSend(send);
84 }
85}
86
87impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
88 type Item = XMPPStream<S>;
89 type Error = String;
90
91 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
92 let state = replace(&mut self.state, ClientAuthState::Invalid);
93
94 match state {
95 ClientAuthState::WaitSend(mut send) =>
96 match send.poll() {
97 Ok(Async::Ready(stream)) => {
98 println!("send done");
99 self.state = ClientAuthState::WaitRecv(stream);
100 self.poll()
101 },
102 Ok(Async::NotReady) => {
103 self.state = ClientAuthState::WaitSend(send);
104 Ok(Async::NotReady)
105 },
106 Err(e) =>
107 Err(format!("{}", e)),
108 },
109 ClientAuthState::WaitRecv(mut stream) =>
110 match stream.poll() {
111 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
112 if stanza.name == "challenge"
113 && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
114 {
115 let content = try!(
116 stanza.content_str()
117 .from_base64()
118 .map_err(|e| format!("{}", e))
119 );
120 let response = try!(self.mechanism.response(&content));
121 self.send(stream, "response", &[], &response);
122 self.poll()
123 },
124 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
125 if stanza.name == "success"
126 && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
127 Ok(Async::Ready(stream)),
128 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
129 if stanza.name == "failure"
130 && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
131 {
132 let mut e = None;
133 for child in &stanza.children {
134 match child {
135 &xml::Xml::ElementNode(ref child) => {
136 e = Some(child.name.clone());
137 break
138 },
139 _ => (),
140 }
141 }
142 let e = e.unwrap_or_else(|| "Authentication failure".to_owned());
143 Err(e)
144 },
145 Ok(Async::Ready(event)) => {
146 println!("ClientAuth ignore {:?}", event);
147 Ok(Async::NotReady)
148 },
149 Ok(_) => {
150 self.state = ClientAuthState::WaitRecv(stream);
151 Ok(Async::NotReady)
152 },
153 Err(e) =>
154 Err(format!("{}", e)),
155 },
156 ClientAuthState::Invalid =>
157 unreachable!(),
158 }
159 }
160}