client.rs

  1use jid::Jid;
  2use transport::{Transport, SslTransport};
  3use error::Error;
  4use ns;
  5use plugin::{Plugin, PluginProxyBinding};
  6use event::AbstractEvent;
  7use connection::{Connection, C2S};
  8use sasl::{ Mechanism as SaslMechanism
  9          , Credentials as SaslCredentials
 10          , Secret as SaslSecret
 11          , ChannelBinding
 12          };
 13use sasl::mechanisms::{Plain, Scram, Sha1, Sha256};
 14use components::sasl_error::SaslError;
 15use util::FromElement;
 16
 17use base64;
 18
 19use minidom::Element;
 20
 21use xml::reader::XmlEvent as ReaderEvent;
 22
 23use std::sync::mpsc::{Receiver, channel};
 24
 25use std::collections::HashSet;
 26
 27/// Struct that should be moved somewhere else and cleaned up.
 28#[derive(Debug)]
 29pub struct StreamFeatures {
 30    pub sasl_mechanisms: Option<HashSet<String>>,
 31}
 32
 33/// A builder for `Client`s.
 34pub struct ClientBuilder {
 35    jid: Jid,
 36    credentials: SaslCredentials,
 37    host: Option<String>,
 38    port: u16,
 39}
 40
 41impl ClientBuilder {
 42    /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
 43    pub fn new(jid: Jid) -> ClientBuilder {
 44        ClientBuilder {
 45            jid: jid,
 46            credentials: SaslCredentials::default(),
 47            host: None,
 48            port: 5222,
 49        }
 50    }
 51
 52    /// Sets the host to connect to.
 53    pub fn host(mut self, host: String) -> ClientBuilder {
 54        self.host = Some(host);
 55        self
 56    }
 57
 58    /// Sets the port to connect to.
 59    pub fn port(mut self, port: u16) -> ClientBuilder {
 60        self.port = port;
 61        self
 62    }
 63
 64    /// Sets the password to use.
 65    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
 66        self.credentials = SaslCredentials {
 67            username: Some(self.jid.node.clone().expect("JID has no node")),
 68            secret: SaslSecret::Password(password.into()),
 69            channel_binding: ChannelBinding::None,
 70        };
 71        self
 72    }
 73
 74    /// Connects to the server and returns a `Client` when succesful.
 75    pub fn connect(self) -> Result<Client, Error> {
 76        let host = &self.host.unwrap_or(self.jid.domain.clone());
 77        let mut transport = SslTransport::connect(host, self.port)?;
 78        C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
 79        let (sender_out, sender_in) = channel();
 80        let (dispatcher_out, dispatcher_in) = channel();
 81        let mut credentials = self.credentials;
 82        credentials.channel_binding = transport.channel_bind();
 83        let mut client = Client {
 84            jid: self.jid,
 85            transport: transport,
 86            plugins: Vec::new(),
 87            binding: PluginProxyBinding::new(sender_out, dispatcher_out),
 88            sender_in: sender_in,
 89            dispatcher_in: dispatcher_in,
 90        };
 91        client.connect(credentials)?;
 92        client.bind()?;
 93        Ok(client)
 94    }
 95}
 96
 97/// An XMPP client.
 98pub struct Client {
 99    jid: Jid,
100    transport: SslTransport,
101    plugins: Vec<Box<Plugin>>,
102    binding: PluginProxyBinding,
103    sender_in: Receiver<Element>,
104    dispatcher_in: Receiver<AbstractEvent>,
105}
106
107impl Client {
108    /// Returns a reference to the `Jid` associated with this `Client`.
109    pub fn jid(&self) -> &Jid {
110        &self.jid
111    }
112
113    /// Registers a plugin.
114    pub fn register_plugin<P: Plugin + 'static>(&mut self, mut plugin: P) {
115        plugin.bind(self.binding.clone());
116        self.plugins.push(Box::new(plugin));
117    }
118
119    /// Returns the plugin given by the type parameter, if it exists, else panics.
120    pub fn plugin<P: Plugin>(&self) -> &P {
121        for plugin in &self.plugins {
122            let any = plugin.as_any();
123            if let Some(ret) = any.downcast_ref::<P>() {
124                return ret;
125            }
126        }
127        panic!("plugin does not exist!");
128    }
129
130    /// Returns the next event and flush the send queue.
131    pub fn next_event(&mut self) -> Result<AbstractEvent, Error> {
132        self.flush_send_queue()?;
133        loop {
134            if let Ok(evt) = self.dispatcher_in.try_recv() {
135                return Ok(evt);
136            }
137            let elem = self.transport.read_element()?;
138            for plugin in self.plugins.iter_mut() {
139                plugin.handle(&elem);
140                // TODO: handle plugin return
141            }
142            self.flush_send_queue()?;
143        }
144    }
145
146    /// Flushes the send queue, sending all queued up stanzas.
147    pub fn flush_send_queue(&mut self) -> Result<(), Error> { // TODO: not sure how great of an
148                                                              //       idea it is to flush in this
149                                                              //       manner…
150        while let Ok(elem) = self.sender_in.try_recv() {
151            self.transport.write_element(&elem)?;
152        }
153        Ok(())
154    }
155
156    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
157        let features = self.wait_for_features()?;
158        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
159        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
160        // TODO: better way for selecting these, enabling anonymous auth
161        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
162            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
163        }
164        else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
165            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
166        }
167        else if ms.contains("SCRAM-SHA-256") {
168            if credentials.channel_binding != ChannelBinding::None {
169                credentials.channel_binding = ChannelBinding::Unsupported;
170            }
171            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
172        }
173        else if ms.contains("SCRAM-SHA-1") {
174            if credentials.channel_binding != ChannelBinding::None {
175                credentials.channel_binding = ChannelBinding::Unsupported;
176            }
177            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
178        }
179        else if ms.contains("PLAIN") {
180            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
181        }
182        else {
183            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
184        };
185        let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
186        let mut elem = Element::builder("auth")
187                               .ns(ns::SASL)
188                               .attr("mechanism", mechanism.name())
189                               .build();
190        if !auth.is_empty() {
191            elem.append_text_node(base64::encode(&auth));
192        }
193        self.transport.write_element(&elem)?;
194        loop {
195            let n = self.transport.read_element()?;
196            if n.is("challenge", ns::SASL) {
197                let text = n.text();
198                let challenge = if text == "" {
199                    Vec::new()
200                }
201                else {
202                    base64::decode(&text)?
203                };
204                let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
205                let mut elem = Element::builder("response")
206                                       .ns(ns::SASL)
207                                       .build();
208                if !response.is_empty() {
209                    elem.append_text_node(base64::encode(&response));
210                }
211                self.transport.write_element(&elem)?;
212            }
213            else if n.is("success", ns::SASL) {
214                let text = n.text();
215                let data = if text == "" {
216                    Vec::new()
217                }
218                else {
219                    base64::decode(&text)?
220                };
221                mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
222                self.transport.reset_stream();
223                C2S::init(&mut self.transport, &self.jid.domain, "after_sasl")?;
224                self.wait_for_features()?;
225                return Ok(());
226            }
227            else if n.is("failure", ns::SASL) {
228                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
229                return Err(Error::XmppSaslError(inner));
230            }
231        }
232    }
233
234    fn bind(&mut self) -> Result<(), Error> {
235        let mut elem = Element::builder("iq")
236                               .attr("id", "bind")
237                               .attr("type", "set")
238                               .build();
239        let mut bind = Element::builder("bind")
240                               .ns(ns::BIND)
241                               .build();
242        if let Some(ref resource) = self.jid.resource {
243            let res = Element::builder("resource")
244                              .ns(ns::BIND)
245                              .append(resource.to_owned())
246                              .build();
247            bind.append_child(res);
248        }
249        elem.append_child(bind);
250        self.transport.write_element(&elem)?;
251        loop {
252            let n = self.transport.read_element()?;
253            if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
254                return Ok(());
255            }
256        }
257    }
258
259    fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
260        // TODO: this is very ugly
261        loop {
262            let e = self.transport.read_event()?;
263            match e {
264                ReaderEvent::StartElement { .. } => {
265                    break;
266                },
267                _ => (),
268            }
269        }
270        loop {
271            let n = self.transport.read_element()?;
272            if n.is("features", ns::STREAM) {
273                let mut features = StreamFeatures {
274                    sasl_mechanisms: None,
275                };
276                if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
277                    let mut res = HashSet::new();
278                    for cld in ms.children() {
279                        res.insert(cld.text());
280                    }
281                    features.sasl_mechanisms = Some(res);
282                }
283                return Ok(features);
284            }
285        }
286    }
287}