client.rs

  1use xml;
  2use jid::Jid;
  3use transport::{Transport, SslTransport};
  4use error::Error;
  5use ns;
  6use plugin::{Plugin, PluginInit, PluginProxyBinding, PluginContainer, PluginRef};
  7use connection::{Connection, C2S};
  8use sasl::client::Mechanism as SaslMechanism;
  9use sasl::client::mechanisms::{Plain, Scram};
 10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
 11use sasl::common::scram::{Sha1, Sha256};
 12use components::sasl_error::SaslError;
 13use util::FromElement;
 14use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
 15
 16use base64;
 17
 18use minidom::Element;
 19
 20use xml::reader::XmlEvent as ReaderEvent;
 21
 22use std::sync::{Mutex, Arc};
 23
 24use std::collections::HashSet;
 25
 26/// Struct that should be moved somewhere else and cleaned up.
 27#[derive(Debug)]
 28pub struct StreamFeatures {
 29    pub sasl_mechanisms: Option<HashSet<String>>,
 30}
 31
/// A builder for `Client`s.
///
/// Collects the connection parameters (JID, credentials, host and port) and turns them into a
/// connected `Client` through `connect`.
pub struct ClientBuilder {
    /// The JID to log in as; its domain doubles as the host when `host` is `None`.
    jid: Jid,
    /// SASL credentials; start out as `SaslCredentials::default()` and are replaced by `password`.
    credentials: SaslCredentials,
    /// Explicit host to connect to, if one was set with `host`.
    host: Option<String>,
    /// Port to connect to; `new` initialises it to 5222.
    port: u16,
}
 39
 40impl ClientBuilder {
 41    /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
 42    pub fn new(jid: Jid) -> ClientBuilder {
 43        ClientBuilder {
 44            jid: jid,
 45            credentials: SaslCredentials::default(),
 46            host: None,
 47            port: 5222,
 48        }
 49    }
 50
 51    /// Sets the host to connect to.
 52    pub fn host(mut self, host: String) -> ClientBuilder {
 53        self.host = Some(host);
 54        self
 55    }
 56
 57    /// Sets the port to connect to.
 58    pub fn port(mut self, port: u16) -> ClientBuilder {
 59        self.port = port;
 60        self
 61    }
 62
 63    /// Sets the password to use.
 64    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
 65        self.credentials = SaslCredentials {
 66            identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
 67            secret: Secret::password_plain(password),
 68            channel_binding: ChannelBinding::None,
 69        };
 70        self
 71    }
 72
 73    /// Connects to the server and returns a `Client` when succesful.
 74    pub fn connect(self) -> Result<Client, Error> {
 75        let host = &self.host.unwrap_or(self.jid.domain.clone());
 76        let mut transport = SslTransport::connect(host, self.port)?;
 77        C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
 78        let dispatcher = Arc::new(Dispatcher::new());
 79        let mut credentials = self.credentials;
 80        credentials.channel_binding = transport.channel_bind();
 81        let transport = Arc::new(Mutex::new(transport));
 82        let plugin_container = Arc::new(PluginContainer::new());
 83        let mut client = Client {
 84            jid: self.jid.clone(),
 85            transport: transport.clone(),
 86            binding: PluginProxyBinding::new(dispatcher.clone(), plugin_container.clone(), self.jid),
 87            plugin_container: plugin_container,
 88            dispatcher: dispatcher,
 89        };
 90        client.dispatcher.register(Priority::Default, move |evt: &SendElement| {
 91            let mut t = transport.lock().unwrap();
 92            t.write_element(&evt.0).unwrap();
 93            Propagation::Continue
 94        });
 95        client.connect(credentials)?;
 96        client.bind()?;
 97        Ok(client)
 98    }
 99}
100
/// An XMPP client.
///
/// Built through `ClientBuilder::connect`, which hands it back once authentication and
/// resource binding have completed.
pub struct Client {
    /// The JID this client was built for.
    jid: Jid,
    /// The TLS transport, shared with the `SendElement` handler installed at connect time.
    transport: Arc<Mutex<SslTransport>>,
    /// Holds the plugins added through `register_plugin`.
    plugin_container: Arc<PluginContainer>,
    /// Binding that is cloned and handed to every plugin on registration.
    binding: PluginProxyBinding,
    /// Event dispatcher that handlers and plugins are registered with.
    dispatcher: Arc<Dispatcher>,
}
109
110impl Client {
    /// Returns a reference to the `Jid` associated with this `Client`.
    ///
    /// This is the JID the client was built with; it is not modified after construction
    /// anywhere in this file.
    pub fn jid(&self) -> &Jid {
        &self.jid
    }
115
116    /// Registers a plugin.
117    pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
118        let binding = self.binding.clone();
119        plugin.bind(binding);
120        let p = Arc::new(plugin);
121        P::init(&self.dispatcher, p.clone());
122        self.plugin_container.register(p);
123    }
124
125    pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
126        where
127            E: Event,
128            F: Fn(&E) -> Propagation + 'static {
129        self.dispatcher.register(pri, func);
130    }
131
132    /// Returns the plugin given by the type parameter, if it exists, else panics.
133    pub fn plugin<P: Plugin>(&self) -> PluginRef<P> {
134        self.plugin_container.get::<P>().unwrap()
135    }
136
137    /// Returns the next event and flush the send queue.
138    pub fn main(&mut self) -> Result<(), Error> {
139        self.dispatcher.flush_all();
140        loop {
141            let elem = self.read_element()?;
142            self.dispatcher.dispatch(ReceiveElement(elem));
143            self.dispatcher.flush_all();
144        }
145    }
146
147    fn reset_stream(&self) {
148        self.transport.lock().unwrap().reset_stream()
149    }
150
151    fn read_element(&self) -> Result<Element, Error> {
152        self.transport.lock().unwrap().read_element()
153    }
154
155    fn write_element(&self, elem: &Element) -> Result<(), Error> {
156        self.transport.lock().unwrap().write_element(elem)
157    }
158
159    fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
160        self.transport.lock().unwrap().read_event()
161    }
162
163    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
164        let features = self.wait_for_features()?;
165        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
166        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
167        // TODO: better way for selecting these, enabling anonymous auth
168        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
169            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
170        }
171        else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
172            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
173        }
174        else if ms.contains("SCRAM-SHA-256") {
175            if credentials.channel_binding != ChannelBinding::None {
176                credentials.channel_binding = ChannelBinding::Unsupported;
177            }
178            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
179        }
180        else if ms.contains("SCRAM-SHA-1") {
181            if credentials.channel_binding != ChannelBinding::None {
182                credentials.channel_binding = ChannelBinding::Unsupported;
183            }
184            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
185        }
186        else if ms.contains("PLAIN") {
187            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
188        }
189        else {
190            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
191        };
192        let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
193        let mut elem = Element::builder("auth")
194                               .ns(ns::SASL)
195                               .attr("mechanism", mechanism.name())
196                               .build();
197        if !auth.is_empty() {
198            elem.append_text_node(base64::encode(&auth));
199        }
200        self.write_element(&elem)?;
201        loop {
202            let n = self.read_element()?;
203            if n.is("challenge", ns::SASL) {
204                let text = n.text();
205                let challenge = if text == "" {
206                    Vec::new()
207                }
208                else {
209                    base64::decode(&text)?
210                };
211                let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
212                let mut elem = Element::builder("response")
213                                       .ns(ns::SASL)
214                                       .build();
215                if !response.is_empty() {
216                    elem.append_text_node(base64::encode(&response));
217                }
218                self.write_element(&elem)?;
219            }
220            else if n.is("success", ns::SASL) {
221                let text = n.text();
222                let data = if text == "" {
223                    Vec::new()
224                }
225                else {
226                    base64::decode(&text)?
227                };
228                mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
229                self.reset_stream();
230                {
231                    let mut g = self.transport.lock().unwrap();
232                    C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
233                }
234                self.wait_for_features()?;
235                return Ok(());
236            }
237            else if n.is("failure", ns::SASL) {
238                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
239                return Err(Error::XmppSaslError(inner));
240            }
241        }
242    }
243
244    fn bind(&mut self) -> Result<(), Error> {
245        let mut elem = Element::builder("iq")
246                               .attr("id", "bind")
247                               .attr("type", "set")
248                               .build();
249        let mut bind = Element::builder("bind")
250                               .ns(ns::BIND)
251                               .build();
252        if let Some(ref resource) = self.jid.resource {
253            let res = Element::builder("resource")
254                              .ns(ns::BIND)
255                              .append(resource.to_owned())
256                              .build();
257            bind.append_child(res);
258        }
259        elem.append_child(bind);
260        self.write_element(&elem)?;
261        loop {
262            let n = self.read_element()?;
263            if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
264                return Ok(());
265            }
266        }
267    }
268
269    fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
270        // TODO: this is very ugly
271        loop {
272            let e = self.read_event()?;
273            match e {
274                ReaderEvent::StartElement { .. } => {
275                    break;
276                },
277                _ => (),
278            }
279        }
280        loop {
281            let n = self.read_element()?;
282            if n.is("features", ns::STREAM) {
283                let mut features = StreamFeatures {
284                    sasl_mechanisms: None,
285                };
286                if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
287                    let mut res = HashSet::new();
288                    for cld in ms.children() {
289                        res.insert(cld.text());
290                    }
291                    features.sasl_mechanisms = Some(res);
292                }
293                return Ok(features);
294            }
295        }
296    }
297}