client.rs

  1use jid::Jid;
  2use transport::{Transport, SslTransport};
  3use error::Error;
  4use ns;
  5use plugin::{Plugin, PluginInit, PluginProxyBinding, PluginContainer, PluginRef};
  6use connection::{Connection, C2S};
  7use sasl::client::Mechanism as SaslMechanism;
  8use sasl::client::mechanisms::{Plain, Scram};
  9use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
 10use sasl::common::scram::{Sha1, Sha256};
 11use components::sasl_error::SaslError;
 12use util::FromElement;
 13use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
 14
 15use base64;
 16
 17use minidom::Element;
 18
 19use quick_xml::events::Event as XmlEvent;
 20
 21use std::sync::{Mutex, Arc};
 22
 23use std::collections::HashSet;
 24
 /// Features advertised by the server in a `<stream:features/>` element.
///
/// Only the SASL mechanisms are extracted so far.
/// TODO: this struct should be moved somewhere else and cleaned up.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamFeatures {
    /// Names of the advertised SASL mechanisms (e.g. `"PLAIN"`), or `None`
    /// when the features element had no `<mechanisms/>` child.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
 30
 31/// A builder for `Client`s.
 32pub struct ClientBuilder {
 33    jid: Jid,
 34    credentials: SaslCredentials,
 35    host: Option<String>,
 36    port: u16,
 37}
 38
 39impl ClientBuilder {
 40    /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
 41    pub fn new(jid: Jid) -> ClientBuilder {
 42        ClientBuilder {
 43            jid: jid,
 44            credentials: SaslCredentials::default(),
 45            host: None,
 46            port: 5222,
 47        }
 48    }
 49
 50    /// Sets the host to connect to.
 51    pub fn host(mut self, host: String) -> ClientBuilder {
 52        self.host = Some(host);
 53        self
 54    }
 55
 56    /// Sets the port to connect to.
 57    pub fn port(mut self, port: u16) -> ClientBuilder {
 58        self.port = port;
 59        self
 60    }
 61
 62    /// Sets the password to use.
 63    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
 64        self.credentials = SaslCredentials {
 65            identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
 66            secret: Secret::password_plain(password),
 67            channel_binding: ChannelBinding::None,
 68        };
 69        self
 70    }
 71
 72    /// Connects to the server and returns a `Client` when succesful.
 73    pub fn connect(self) -> Result<Client, Error> {
 74        let host = &self.host.unwrap_or(self.jid.domain.clone());
 75        let mut transport = SslTransport::connect(host, self.port)?;
 76        C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
 77        let dispatcher = Arc::new(Dispatcher::new());
 78        let mut credentials = self.credentials;
 79        credentials.channel_binding = transport.channel_bind();
 80        let transport = Arc::new(Mutex::new(transport));
 81        let plugin_container = Arc::new(PluginContainer::new());
 82        let mut client = Client {
 83            jid: self.jid.clone(),
 84            transport: transport.clone(),
 85            binding: PluginProxyBinding::new(dispatcher.clone(), plugin_container.clone(), self.jid),
 86            plugin_container: plugin_container,
 87            dispatcher: dispatcher,
 88        };
 89        client.dispatcher.register(Priority::Default, move |evt: &SendElement| {
 90            let mut t = transport.lock().unwrap();
 91            t.write_element(&evt.0).unwrap();
 92            Propagation::Continue
 93        });
 94        client.connect(credentials)?;
 95        client.bind()?;
 96        Ok(client)
 97    }
 98}
 99
100/// An XMPP client.
101pub struct Client {
102    jid: Jid,
103    transport: Arc<Mutex<SslTransport>>,
104    plugin_container: Arc<PluginContainer>,
105    binding: PluginProxyBinding,
106    dispatcher: Arc<Dispatcher>,
107}
108
109impl Client {
110    /// Returns a reference to the `Jid` associated with this `Client`.
111    pub fn jid(&self) -> &Jid {
112        &self.jid
113    }
114
115    /// Registers a plugin.
116    pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
117        let binding = self.binding.clone();
118        plugin.bind(binding);
119        let p = Arc::new(plugin);
120        P::init(&self.dispatcher, p.clone());
121        self.plugin_container.register(p);
122    }
123
124    pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
125        where
126            E: Event,
127            F: Fn(&E) -> Propagation + 'static {
128        self.dispatcher.register(pri, func);
129    }
130
131    /// Returns the plugin given by the type parameter, if it exists, else panics.
132    pub fn plugin<P: Plugin>(&self) -> PluginRef<P> {
133        self.plugin_container.get::<P>().unwrap()
134    }
135
136    /// Returns the next event and flush the send queue.
137    pub fn main(&mut self) -> Result<(), Error> {
138        self.dispatcher.flush_all();
139        loop {
140            let elem = self.read_element()?;
141            self.dispatcher.dispatch(ReceiveElement(elem));
142            self.dispatcher.flush_all();
143        }
144    }
145
146    fn reset_stream(&self) {
147        self.transport.lock().unwrap().reset_stream()
148    }
149
150    fn read_element(&self) -> Result<Element, Error> {
151        self.transport.lock().unwrap().read_element()
152    }
153
154    fn write_element(&self, elem: &Element) -> Result<(), Error> {
155        self.transport.lock().unwrap().write_element(elem)
156    }
157
158    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
159        let features = self.wait_for_features()?;
160        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
161        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
162        // TODO: better way for selecting these, enabling anonymous auth
163        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
164            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
165        }
166        else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
167            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
168        }
169        else if ms.contains("SCRAM-SHA-256") {
170            if credentials.channel_binding != ChannelBinding::None {
171                credentials.channel_binding = ChannelBinding::Unsupported;
172            }
173            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
174        }
175        else if ms.contains("SCRAM-SHA-1") {
176            if credentials.channel_binding != ChannelBinding::None {
177                credentials.channel_binding = ChannelBinding::Unsupported;
178            }
179            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
180        }
181        else if ms.contains("PLAIN") {
182            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
183        }
184        else {
185            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
186        };
187        let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
188        let mut elem = Element::builder("auth")
189                               .ns(ns::SASL)
190                               .attr("mechanism", mechanism.name())
191                               .build();
192        if !auth.is_empty() {
193            elem.append_text_node(base64::encode(&auth));
194        }
195        self.write_element(&elem)?;
196        loop {
197            let n = self.read_element()?;
198            if n.is("challenge", ns::SASL) {
199                let text = n.text();
200                let challenge = if text == "" {
201                    Vec::new()
202                }
203                else {
204                    base64::decode(&text)?
205                };
206                let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
207                let mut elem = Element::builder("response")
208                                       .ns(ns::SASL)
209                                       .build();
210                if !response.is_empty() {
211                    elem.append_text_node(base64::encode(&response));
212                }
213                self.write_element(&elem)?;
214            }
215            else if n.is("success", ns::SASL) {
216                let text = n.text();
217                let data = if text == "" {
218                    Vec::new()
219                }
220                else {
221                    base64::decode(&text)?
222                };
223                mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
224                self.reset_stream();
225                {
226                    let mut g = self.transport.lock().unwrap();
227                    C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
228                }
229                self.wait_for_features()?;
230                return Ok(());
231            }
232            else if n.is("failure", ns::SASL) {
233                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
234                return Err(Error::XmppSaslError(inner));
235            }
236        }
237    }
238
239    fn bind(&mut self) -> Result<(), Error> {
240        let mut elem = Element::builder("iq")
241                               .attr("id", "bind")
242                               .attr("type", "set")
243                               .build();
244        let mut bind = Element::builder("bind")
245                               .ns(ns::BIND)
246                               .build();
247        if let Some(ref resource) = self.jid.resource {
248            let res = Element::builder("resource")
249                              .ns(ns::BIND)
250                              .append(resource.to_owned())
251                              .build();
252            bind.append_child(res);
253        }
254        elem.append_child(bind);
255        self.write_element(&elem)?;
256        loop {
257            let n = self.read_element()?;
258            if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
259                return Ok(());
260            }
261        }
262    }
263
264    fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
265        // TODO: this is very ugly
266        loop {
267            let mut transport = self.transport.lock().unwrap();
268            let e = transport.read_event();
269            match e {
270                Ok(XmlEvent::Start { .. }) => {
271                    break;
272                },
273                _ => (),
274            }
275        }
276        loop {
277            let n = self.read_element()?;
278            if n.is("features", ns::STREAM) {
279                let mut features = StreamFeatures {
280                    sasl_mechanisms: None,
281                };
282                if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
283                    let mut res = HashSet::new();
284                    for cld in ms.children() {
285                        res.insert(cld.text());
286                    }
287                    features.sasl_mechanisms = Some(res);
288                }
289                return Ok(features);
290            }
291        }
292    }
293}