1use xml;
2use jid::Jid;
3use transport::{Transport, SslTransport};
4use error::Error;
5use ns;
6use plugin::{Plugin, PluginInit, PluginProxyBinding};
7use connection::{Connection, C2S};
8use sasl::client::Mechanism as SaslMechanism;
9use sasl::client::mechanisms::{Plain, Scram};
10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
11use sasl::common::scram::{Sha1, Sha256};
12use components::sasl_error::SaslError;
13use util::FromElement;
14use event::{Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
15
16use base64;
17
18use minidom::Element;
19
20use xml::reader::XmlEvent as ReaderEvent;
21
22use std::sync::{Mutex, Arc};
23
24use std::collections::{HashSet, HashMap};
25
26use std::any::TypeId;
27
/// Features advertised by the server in its `<stream:features/>` element.
///
/// Only the SASL mechanism list is extracted for now. This type is a stopgap
/// that should eventually move to a dedicated module and be cleaned up.
#[derive(Debug)]
pub struct StreamFeatures {
    /// Names of the SASL mechanisms offered by the server, or `None` when the
    /// features element carried no `<mechanisms/>` child.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
33
34/// A builder for `Client`s.
35pub struct ClientBuilder {
36 jid: Jid,
37 credentials: SaslCredentials,
38 host: Option<String>,
39 port: u16,
40}
41
42impl ClientBuilder {
43 /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
44 pub fn new(jid: Jid) -> ClientBuilder {
45 ClientBuilder {
46 jid: jid,
47 credentials: SaslCredentials::default(),
48 host: None,
49 port: 5222,
50 }
51 }
52
53 /// Sets the host to connect to.
54 pub fn host(mut self, host: String) -> ClientBuilder {
55 self.host = Some(host);
56 self
57 }
58
59 /// Sets the port to connect to.
60 pub fn port(mut self, port: u16) -> ClientBuilder {
61 self.port = port;
62 self
63 }
64
65 /// Sets the password to use.
66 pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
67 self.credentials = SaslCredentials {
68 identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
69 secret: Secret::password_plain(password),
70 channel_binding: ChannelBinding::None,
71 };
72 self
73 }
74
75 /// Connects to the server and returns a `Client` when succesful.
76 pub fn connect(self) -> Result<Client, Error> {
77 let host = &self.host.unwrap_or(self.jid.domain.clone());
78 let mut transport = SslTransport::connect(host, self.port)?;
79 C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
80 let dispatcher = Arc::new(Mutex::new(Dispatcher::new()));
81 let mut credentials = self.credentials;
82 credentials.channel_binding = transport.channel_bind();
83 let transport = Arc::new(Mutex::new(transport));
84 let mut client = Client {
85 jid: self.jid,
86 transport: transport.clone(),
87 plugins: HashMap::new(),
88 binding: PluginProxyBinding::new(dispatcher.clone()),
89 dispatcher: dispatcher,
90 };
91 client.dispatcher.lock().unwrap().register(Priority::Default, move |evt: &SendElement| {
92 let mut t = transport.lock().unwrap();
93 t.write_element(&evt.0).unwrap();
94 Propagation::Continue
95 });
96 client.connect(credentials)?;
97 client.bind()?;
98 Ok(client)
99 }
100}
101
102/// An XMPP client.
103pub struct Client {
104 jid: Jid,
105 transport: Arc<Mutex<SslTransport>>,
106 plugins: HashMap<TypeId, Arc<Box<Plugin>>>,
107 binding: PluginProxyBinding,
108 dispatcher: Arc<Mutex<Dispatcher>>,
109}
110
111impl Client {
112 /// Returns a reference to the `Jid` associated with this `Client`.
113 pub fn jid(&self) -> &Jid {
114 &self.jid
115 }
116
117 /// Registers a plugin.
118 pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
119 let binding = self.binding.clone();
120 plugin.bind(binding);
121 let p = Arc::new(Box::new(plugin) as Box<Plugin>);
122 {
123 let mut disp = self.dispatcher.lock().unwrap();
124 P::init(&mut disp, p.clone());
125 }
126 if self.plugins.insert(TypeId::of::<P>(), p).is_some() {
127 panic!("registering a plugin that's already registered");
128 }
129 }
130
131 /// Returns the plugin given by the type parameter, if it exists, else panics.
132 pub fn plugin<P: Plugin>(&self) -> &P {
133 self.plugins.get(&TypeId::of::<P>())
134 .expect("the requested plugin was not registered")
135 .as_any()
136 .downcast_ref::<P>()
137 .expect("plugin downcast failure (should not happen!!)")
138 }
139
140 /// Returns the next event and flush the send queue.
141 pub fn main(&mut self) -> Result<(), Error> {
142 self.dispatcher.lock().unwrap().flush_all();
143 loop {
144 let elem = self.read_element()?;
145 {
146 let mut disp = self.dispatcher.lock().unwrap();
147 disp.dispatch(ReceiveElement(elem));
148 disp.flush_all();
149 }
150 }
151 }
152
153 fn reset_stream(&self) {
154 self.transport.lock().unwrap().reset_stream()
155 }
156
157 fn read_element(&self) -> Result<Element, Error> {
158 self.transport.lock().unwrap().read_element()
159 }
160
161 fn write_element(&self, elem: &Element) -> Result<(), Error> {
162 self.transport.lock().unwrap().write_element(elem)
163 }
164
165 fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
166 self.transport.lock().unwrap().read_event()
167 }
168
169 fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
170 let features = self.wait_for_features()?;
171 let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
172 fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
173 // TODO: better way for selecting these, enabling anonymous auth
174 let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
175 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
176 }
177 else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
178 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
179 }
180 else if ms.contains("SCRAM-SHA-256") {
181 if credentials.channel_binding != ChannelBinding::None {
182 credentials.channel_binding = ChannelBinding::Unsupported;
183 }
184 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
185 }
186 else if ms.contains("SCRAM-SHA-1") {
187 if credentials.channel_binding != ChannelBinding::None {
188 credentials.channel_binding = ChannelBinding::Unsupported;
189 }
190 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
191 }
192 else if ms.contains("PLAIN") {
193 Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
194 }
195 else {
196 return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
197 };
198 let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
199 let mut elem = Element::builder("auth")
200 .ns(ns::SASL)
201 .attr("mechanism", mechanism.name())
202 .build();
203 if !auth.is_empty() {
204 elem.append_text_node(base64::encode(&auth));
205 }
206 self.write_element(&elem)?;
207 loop {
208 let n = self.read_element()?;
209 if n.is("challenge", ns::SASL) {
210 let text = n.text();
211 let challenge = if text == "" {
212 Vec::new()
213 }
214 else {
215 base64::decode(&text)?
216 };
217 let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
218 let mut elem = Element::builder("response")
219 .ns(ns::SASL)
220 .build();
221 if !response.is_empty() {
222 elem.append_text_node(base64::encode(&response));
223 }
224 self.write_element(&elem)?;
225 }
226 else if n.is("success", ns::SASL) {
227 let text = n.text();
228 let data = if text == "" {
229 Vec::new()
230 }
231 else {
232 base64::decode(&text)?
233 };
234 mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
235 self.reset_stream();
236 {
237 let mut g = self.transport.lock().unwrap();
238 C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
239 }
240 self.wait_for_features()?;
241 return Ok(());
242 }
243 else if n.is("failure", ns::SASL) {
244 let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
245 return Err(Error::XmppSaslError(inner));
246 }
247 }
248 }
249
250 fn bind(&mut self) -> Result<(), Error> {
251 let mut elem = Element::builder("iq")
252 .attr("id", "bind")
253 .attr("type", "set")
254 .build();
255 let mut bind = Element::builder("bind")
256 .ns(ns::BIND)
257 .build();
258 if let Some(ref resource) = self.jid.resource {
259 let res = Element::builder("resource")
260 .ns(ns::BIND)
261 .append(resource.to_owned())
262 .build();
263 bind.append_child(res);
264 }
265 elem.append_child(bind);
266 self.write_element(&elem)?;
267 loop {
268 let n = self.read_element()?;
269 if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
270 return Ok(());
271 }
272 }
273 }
274
275 fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
276 // TODO: this is very ugly
277 loop {
278 let e = self.read_event()?;
279 match e {
280 ReaderEvent::StartElement { .. } => {
281 break;
282 },
283 _ => (),
284 }
285 }
286 loop {
287 let n = self.read_element()?;
288 if n.is("features", ns::STREAM) {
289 let mut features = StreamFeatures {
290 sasl_mechanisms: None,
291 };
292 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
293 let mut res = HashSet::new();
294 for cld in ms.children() {
295 res.insert(cld.text());
296 }
297 features.sasl_mechanisms = Some(res);
298 }
299 return Ok(features);
300 }
301 }
302 }
303}