client.rs

  1use xml;
  2use jid::Jid;
  3use transport::{Transport, SslTransport};
  4use error::Error;
  5use ns;
  6use plugin::{Plugin, PluginInit, PluginProxyBinding};
  7use connection::{Connection, C2S};
  8use sasl::client::Mechanism as SaslMechanism;
  9use sasl::client::mechanisms::{Plain, Scram};
 10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
 11use sasl::common::scram::{Sha1, Sha256};
 12use components::sasl_error::SaslError;
 13use util::FromElement;
 14use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
 15
 16use base64;
 17
 18use minidom::Element;
 19
 20use xml::reader::XmlEvent as ReaderEvent;
 21
 22use std::sync::{Mutex, Arc};
 23
 24use std::collections::{HashSet, HashMap};
 25
 26use std::any::TypeId;
 27
 28/// Struct that should be moved somewhere else and cleaned up.
 29#[derive(Debug)]
 30pub struct StreamFeatures {
 31    pub sasl_mechanisms: Option<HashSet<String>>,
 32}
 33
 34/// A builder for `Client`s.
 35pub struct ClientBuilder {
 36    jid: Jid,
 37    credentials: SaslCredentials,
 38    host: Option<String>,
 39    port: u16,
 40}
 41
 42impl ClientBuilder {
 43    /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
 44    pub fn new(jid: Jid) -> ClientBuilder {
 45        ClientBuilder {
 46            jid: jid,
 47            credentials: SaslCredentials::default(),
 48            host: None,
 49            port: 5222,
 50        }
 51    }
 52
 53    /// Sets the host to connect to.
 54    pub fn host(mut self, host: String) -> ClientBuilder {
 55        self.host = Some(host);
 56        self
 57    }
 58
 59    /// Sets the port to connect to.
 60    pub fn port(mut self, port: u16) -> ClientBuilder {
 61        self.port = port;
 62        self
 63    }
 64
 65    /// Sets the password to use.
 66    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
 67        self.credentials = SaslCredentials {
 68            identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
 69            secret: Secret::password_plain(password),
 70            channel_binding: ChannelBinding::None,
 71        };
 72        self
 73    }
 74
 75    /// Connects to the server and returns a `Client` when succesful.
 76    pub fn connect(self) -> Result<Client, Error> {
 77        let host = &self.host.unwrap_or(self.jid.domain.clone());
 78        let mut transport = SslTransport::connect(host, self.port)?;
 79        C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
 80        let dispatcher = Arc::new(Mutex::new(Dispatcher::new()));
 81        let mut credentials = self.credentials;
 82        credentials.channel_binding = transport.channel_bind();
 83        let transport = Arc::new(Mutex::new(transport));
 84        let mut client = Client {
 85            jid: self.jid,
 86            transport: transport.clone(),
 87            plugins: HashMap::new(),
 88            binding: PluginProxyBinding::new(dispatcher.clone()),
 89            dispatcher: dispatcher,
 90        };
 91        client.dispatcher.lock().unwrap().register(Priority::Default, move |evt: &SendElement| {
 92            let mut t = transport.lock().unwrap();
 93            t.write_element(&evt.0).unwrap();
 94            Propagation::Continue
 95        });
 96        client.connect(credentials)?;
 97        client.bind()?;
 98        Ok(client)
 99    }
100}
101
102/// An XMPP client.
103pub struct Client {
104    jid: Jid,
105    transport: Arc<Mutex<SslTransport>>,
106    plugins: HashMap<TypeId, Arc<Box<Plugin>>>,
107    binding: PluginProxyBinding,
108    dispatcher: Arc<Mutex<Dispatcher>>,
109}
110
111impl Client {
112    /// Returns a reference to the `Jid` associated with this `Client`.
113    pub fn jid(&self) -> &Jid {
114        &self.jid
115    }
116
117    /// Registers a plugin.
118    pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
119        let binding = self.binding.clone();
120        plugin.bind(binding);
121        let p = Arc::new(Box::new(plugin) as Box<Plugin>);
122        {
123            let mut disp = self.dispatcher.lock().unwrap();
124            P::init(&mut disp, p.clone());
125        }
126        if self.plugins.insert(TypeId::of::<P>(), p).is_some() {
127            panic!("registering a plugin that's already registered");
128        }
129    }
130
131    pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
132        where
133            E: Event,
134            F: Fn(&E) -> Propagation + 'static {
135        self.dispatcher.lock().unwrap().register(pri, func);
136    }
137
138    /// Returns the plugin given by the type parameter, if it exists, else panics.
139    pub fn plugin<P: Plugin>(&self) -> &P {
140        self.plugins.get(&TypeId::of::<P>())
141                    .expect("the requested plugin was not registered")
142                    .as_any()
143                    .downcast_ref::<P>()
144                    .expect("plugin downcast failure (should not happen!!)")
145    }
146
147    /// Returns the next event and flush the send queue.
148    pub fn main(&mut self) -> Result<(), Error> {
149        self.dispatcher.lock().unwrap().flush_all();
150        loop {
151            let elem = self.read_element()?;
152            {
153                let mut disp = self.dispatcher.lock().unwrap();
154                disp.dispatch(ReceiveElement(elem));
155                disp.flush_all();
156            }
157        }
158    }
159
160    fn reset_stream(&self) {
161        self.transport.lock().unwrap().reset_stream()
162    }
163
164    fn read_element(&self) -> Result<Element, Error> {
165        self.transport.lock().unwrap().read_element()
166    }
167
168    fn write_element(&self, elem: &Element) -> Result<(), Error> {
169        self.transport.lock().unwrap().write_element(elem)
170    }
171
172    fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
173        self.transport.lock().unwrap().read_event()
174    }
175
176    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
177        let features = self.wait_for_features()?;
178        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
179        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
180        // TODO: better way for selecting these, enabling anonymous auth
181        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
182            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
183        }
184        else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
185            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
186        }
187        else if ms.contains("SCRAM-SHA-256") {
188            if credentials.channel_binding != ChannelBinding::None {
189                credentials.channel_binding = ChannelBinding::Unsupported;
190            }
191            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
192        }
193        else if ms.contains("SCRAM-SHA-1") {
194            if credentials.channel_binding != ChannelBinding::None {
195                credentials.channel_binding = ChannelBinding::Unsupported;
196            }
197            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
198        }
199        else if ms.contains("PLAIN") {
200            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
201        }
202        else {
203            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
204        };
205        let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
206        let mut elem = Element::builder("auth")
207                               .ns(ns::SASL)
208                               .attr("mechanism", mechanism.name())
209                               .build();
210        if !auth.is_empty() {
211            elem.append_text_node(base64::encode(&auth));
212        }
213        self.write_element(&elem)?;
214        loop {
215            let n = self.read_element()?;
216            if n.is("challenge", ns::SASL) {
217                let text = n.text();
218                let challenge = if text == "" {
219                    Vec::new()
220                }
221                else {
222                    base64::decode(&text)?
223                };
224                let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
225                let mut elem = Element::builder("response")
226                                       .ns(ns::SASL)
227                                       .build();
228                if !response.is_empty() {
229                    elem.append_text_node(base64::encode(&response));
230                }
231                self.write_element(&elem)?;
232            }
233            else if n.is("success", ns::SASL) {
234                let text = n.text();
235                let data = if text == "" {
236                    Vec::new()
237                }
238                else {
239                    base64::decode(&text)?
240                };
241                mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
242                self.reset_stream();
243                {
244                    let mut g = self.transport.lock().unwrap();
245                    C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
246                }
247                self.wait_for_features()?;
248                return Ok(());
249            }
250            else if n.is("failure", ns::SASL) {
251                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
252                return Err(Error::XmppSaslError(inner));
253            }
254        }
255    }
256
257    fn bind(&mut self) -> Result<(), Error> {
258        let mut elem = Element::builder("iq")
259                               .attr("id", "bind")
260                               .attr("type", "set")
261                               .build();
262        let mut bind = Element::builder("bind")
263                               .ns(ns::BIND)
264                               .build();
265        if let Some(ref resource) = self.jid.resource {
266            let res = Element::builder("resource")
267                              .ns(ns::BIND)
268                              .append(resource.to_owned())
269                              .build();
270            bind.append_child(res);
271        }
272        elem.append_child(bind);
273        self.write_element(&elem)?;
274        loop {
275            let n = self.read_element()?;
276            if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
277                return Ok(());
278            }
279        }
280    }
281
282    fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
283        // TODO: this is very ugly
284        loop {
285            let e = self.read_event()?;
286            match e {
287                ReaderEvent::StartElement { .. } => {
288                    break;
289                },
290                _ => (),
291            }
292        }
293        loop {
294            let n = self.read_element()?;
295            if n.is("features", ns::STREAM) {
296                let mut features = StreamFeatures {
297                    sasl_mechanisms: None,
298                };
299                if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
300                    let mut res = HashSet::new();
301                    for cld in ms.children() {
302                        res.insert(cld.text());
303                    }
304                    features.sasl_mechanisms = Some(res);
305                }
306                return Ok(features);
307            }
308        }
309    }
310}