1use xml;
2use jid::Jid;
3use transport::{Transport, SslTransport};
4use error::Error;
5use ns;
6use plugin::{Plugin, PluginInit, PluginProxyBinding};
7use connection::{Connection, C2S};
8use sasl::client::Mechanism as SaslMechanism;
9use sasl::client::mechanisms::{Plain, Scram};
10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
11use sasl::common::scram::{Sha1, Sha256};
12use components::sasl_error::SaslError;
13use util::FromElement;
14use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
15
16use base64;
17
18use minidom::Element;
19
20use xml::reader::XmlEvent as ReaderEvent;
21
22use std::sync::{Mutex, Arc};
23
24use std::collections::{HashSet, HashMap};
25
26use std::any::TypeId;
27
/// Stream features advertised by the server in a `<stream:features/>` element.
///
/// Only the subset of features this client currently inspects is recorded.
// TODO: move this somewhere more appropriate and model the remaining features.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamFeatures {
    /// Names of the SASL mechanisms offered by the server, or `None` when the
    /// features element carried no SASL `<mechanisms/>` child.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
33
34/// A builder for `Client`s.
35pub struct ClientBuilder {
36 jid: Jid,
37 credentials: SaslCredentials,
38 host: Option<String>,
39 port: u16,
40}
41
42impl ClientBuilder {
43 /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
44 pub fn new(jid: Jid) -> ClientBuilder {
45 ClientBuilder {
46 jid: jid,
47 credentials: SaslCredentials::default(),
48 host: None,
49 port: 5222,
50 }
51 }
52
53 /// Sets the host to connect to.
54 pub fn host(mut self, host: String) -> ClientBuilder {
55 self.host = Some(host);
56 self
57 }
58
59 /// Sets the port to connect to.
60 pub fn port(mut self, port: u16) -> ClientBuilder {
61 self.port = port;
62 self
63 }
64
65 /// Sets the password to use.
66 pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
67 self.credentials = SaslCredentials {
68 identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
69 secret: Secret::password_plain(password),
70 channel_binding: ChannelBinding::None,
71 };
72 self
73 }
74
75 /// Connects to the server and returns a `Client` when succesful.
76 pub fn connect(self) -> Result<Client, Error> {
77 let host = &self.host.unwrap_or(self.jid.domain.clone());
78 let mut transport = SslTransport::connect(host, self.port)?;
79 C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
80 let dispatcher = Arc::new(Mutex::new(Dispatcher::new()));
81 let mut credentials = self.credentials;
82 credentials.channel_binding = transport.channel_bind();
83 let transport = Arc::new(Mutex::new(transport));
84 let mut client = Client {
85 jid: self.jid,
86 transport: transport.clone(),
87 plugins: HashMap::new(),
88 binding: PluginProxyBinding::new(dispatcher.clone()),
89 dispatcher: dispatcher,
90 };
91 client.dispatcher.lock().unwrap().register(Priority::Default, move |evt: &SendElement| {
92 let mut t = transport.lock().unwrap();
93 t.write_element(&evt.0).unwrap();
94 Propagation::Continue
95 });
96 client.connect(credentials)?;
97 client.bind()?;
98 Ok(client)
99 }
100}
101
102/// An XMPP client.
103pub struct Client {
104 jid: Jid,
105 transport: Arc<Mutex<SslTransport>>,
106 plugins: HashMap<TypeId, Arc<Box<Plugin>>>,
107 binding: PluginProxyBinding,
108 dispatcher: Arc<Mutex<Dispatcher>>,
109}
110
111impl Client {
112 /// Returns a reference to the `Jid` associated with this `Client`.
113 pub fn jid(&self) -> &Jid {
114 &self.jid
115 }
116
117 /// Registers a plugin.
118 pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
119 let binding = self.binding.clone();
120 plugin.bind(binding);
121 let p = Arc::new(Box::new(plugin) as Box<Plugin>);
122 {
123 let mut disp = self.dispatcher.lock().unwrap();
124 P::init(&mut disp, p.clone());
125 }
126 if self.plugins.insert(TypeId::of::<P>(), p).is_some() {
127 panic!("registering a plugin that's already registered");
128 }
129 }
130
131 pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
132 where
133 E: Event,
134 F: Fn(&E) -> Propagation + 'static {
135 self.dispatcher.lock().unwrap().register(pri, func);
136 }
137
138 /// Returns the plugin given by the type parameter, if it exists, else panics.
139 pub fn plugin<P: Plugin>(&self) -> &P {
140 self.plugins.get(&TypeId::of::<P>())
141 .expect("the requested plugin was not registered")
142 .as_any()
143 .downcast_ref::<P>()
144 .expect("plugin downcast failure (should not happen!!)")
145 }
146
147 /// Returns the next event and flush the send queue.
148 pub fn main(&mut self) -> Result<(), Error> {
149 self.dispatcher.lock().unwrap().flush_all();
150 loop {
151 let elem = self.read_element()?;
152 {
153 let mut disp = self.dispatcher.lock().unwrap();
154 disp.dispatch(ReceiveElement(elem));
155 disp.flush_all();
156 }
157 }
158 }
159
160 fn reset_stream(&self) {
161 self.transport.lock().unwrap().reset_stream()
162 }
163
164 fn read_element(&self) -> Result<Element, Error> {
165 self.transport.lock().unwrap().read_element()
166 }
167
168 fn write_element(&self, elem: &Element) -> Result<(), Error> {
169 self.transport.lock().unwrap().write_element(elem)
170 }
171
172 fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
173 self.transport.lock().unwrap().read_event()
174 }
175
176 fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
177 let features = self.wait_for_features()?;
178 let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
179 fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
180 // TODO: better way for selecting these, enabling anonymous auth
181 let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
182 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
183 }
184 else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
185 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
186 }
187 else if ms.contains("SCRAM-SHA-256") {
188 if credentials.channel_binding != ChannelBinding::None {
189 credentials.channel_binding = ChannelBinding::Unsupported;
190 }
191 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
192 }
193 else if ms.contains("SCRAM-SHA-1") {
194 if credentials.channel_binding != ChannelBinding::None {
195 credentials.channel_binding = ChannelBinding::Unsupported;
196 }
197 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
198 }
199 else if ms.contains("PLAIN") {
200 Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
201 }
202 else {
203 return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
204 };
205 let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
206 let mut elem = Element::builder("auth")
207 .ns(ns::SASL)
208 .attr("mechanism", mechanism.name())
209 .build();
210 if !auth.is_empty() {
211 elem.append_text_node(base64::encode(&auth));
212 }
213 self.write_element(&elem)?;
214 loop {
215 let n = self.read_element()?;
216 if n.is("challenge", ns::SASL) {
217 let text = n.text();
218 let challenge = if text == "" {
219 Vec::new()
220 }
221 else {
222 base64::decode(&text)?
223 };
224 let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
225 let mut elem = Element::builder("response")
226 .ns(ns::SASL)
227 .build();
228 if !response.is_empty() {
229 elem.append_text_node(base64::encode(&response));
230 }
231 self.write_element(&elem)?;
232 }
233 else if n.is("success", ns::SASL) {
234 let text = n.text();
235 let data = if text == "" {
236 Vec::new()
237 }
238 else {
239 base64::decode(&text)?
240 };
241 mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
242 self.reset_stream();
243 {
244 let mut g = self.transport.lock().unwrap();
245 C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
246 }
247 self.wait_for_features()?;
248 return Ok(());
249 }
250 else if n.is("failure", ns::SASL) {
251 let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
252 return Err(Error::XmppSaslError(inner));
253 }
254 }
255 }
256
257 fn bind(&mut self) -> Result<(), Error> {
258 let mut elem = Element::builder("iq")
259 .attr("id", "bind")
260 .attr("type", "set")
261 .build();
262 let mut bind = Element::builder("bind")
263 .ns(ns::BIND)
264 .build();
265 if let Some(ref resource) = self.jid.resource {
266 let res = Element::builder("resource")
267 .ns(ns::BIND)
268 .append(resource.to_owned())
269 .build();
270 bind.append_child(res);
271 }
272 elem.append_child(bind);
273 self.write_element(&elem)?;
274 loop {
275 let n = self.read_element()?;
276 if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
277 return Ok(());
278 }
279 }
280 }
281
282 fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
283 // TODO: this is very ugly
284 loop {
285 let e = self.read_event()?;
286 match e {
287 ReaderEvent::StartElement { .. } => {
288 break;
289 },
290 _ => (),
291 }
292 }
293 loop {
294 let n = self.read_element()?;
295 if n.is("features", ns::STREAM) {
296 let mut features = StreamFeatures {
297 sasl_mechanisms: None,
298 };
299 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
300 let mut res = HashSet::new();
301 for cld in ms.children() {
302 res.insert(cld.text());
303 }
304 features.sasl_mechanisms = Some(res);
305 }
306 return Ok(features);
307 }
308 }
309 }
310}