client.rs

  1use xml;
  2use jid::Jid;
  3use transport::{Transport, SslTransport};
  4use error::Error;
  5use ns;
  6use plugin::{Plugin, PluginInit, PluginProxyBinding};
  7use connection::{Connection, C2S};
  8use sasl::client::Mechanism as SaslMechanism;
  9use sasl::client::mechanisms::{Plain, Scram};
 10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
 11use sasl::common::scram::{Sha1, Sha256};
 12use components::sasl_error::SaslError;
 13use util::FromElement;
 14use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
 15
 16use base64;
 17
 18use minidom::Element;
 19
 20use xml::reader::XmlEvent as ReaderEvent;
 21
 22use std::sync::{Mutex, Arc};
 23
 24use std::collections::{HashSet, HashMap};
 25
 26use std::any::TypeId;
 27
 28/// Struct that should be moved somewhere else and cleaned up.
 29#[derive(Debug)]
 30pub struct StreamFeatures {
 31    pub sasl_mechanisms: Option<HashSet<String>>,
 32}
 33
 34/// A builder for `Client`s.
 35pub struct ClientBuilder {
 36    jid: Jid,
 37    credentials: SaslCredentials,
 38    host: Option<String>,
 39    port: u16,
 40}
 41
 42impl ClientBuilder {
 43    /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
 44    pub fn new(jid: Jid) -> ClientBuilder {
 45        ClientBuilder {
 46            jid: jid,
 47            credentials: SaslCredentials::default(),
 48            host: None,
 49            port: 5222,
 50        }
 51    }
 52
 53    /// Sets the host to connect to.
 54    pub fn host(mut self, host: String) -> ClientBuilder {
 55        self.host = Some(host);
 56        self
 57    }
 58
 59    /// Sets the port to connect to.
 60    pub fn port(mut self, port: u16) -> ClientBuilder {
 61        self.port = port;
 62        self
 63    }
 64
 65    /// Sets the password to use.
 66    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
 67        self.credentials = SaslCredentials {
 68            identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
 69            secret: Secret::password_plain(password),
 70            channel_binding: ChannelBinding::None,
 71        };
 72        self
 73    }
 74
 75    /// Connects to the server and returns a `Client` when succesful.
 76    pub fn connect(self) -> Result<Client, Error> {
 77        let host = &self.host.unwrap_or(self.jid.domain.clone());
 78        let mut transport = SslTransport::connect(host, self.port)?;
 79        C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
 80        let dispatcher = Arc::new(Dispatcher::new());
 81        let mut credentials = self.credentials;
 82        credentials.channel_binding = transport.channel_bind();
 83        let transport = Arc::new(Mutex::new(transport));
 84        let mut client = Client {
 85            jid: self.jid,
 86            transport: transport.clone(),
 87            plugins: HashMap::new(),
 88            binding: PluginProxyBinding::new(dispatcher.clone()),
 89            dispatcher: dispatcher,
 90        };
 91        client.dispatcher.register(Priority::Default, move |evt: &SendElement| {
 92            let mut t = transport.lock().unwrap();
 93            t.write_element(&evt.0).unwrap();
 94            Propagation::Continue
 95        });
 96        client.connect(credentials)?;
 97        client.bind()?;
 98        Ok(client)
 99    }
100}
101
102/// An XMPP client.
103pub struct Client {
104    jid: Jid,
105    transport: Arc<Mutex<SslTransport>>,
106    plugins: HashMap<TypeId, Arc<Plugin>>,
107    binding: PluginProxyBinding,
108    dispatcher: Arc<Dispatcher>,
109}
110
111impl Client {
112    /// Returns a reference to the `Jid` associated with this `Client`.
113    pub fn jid(&self) -> &Jid {
114        &self.jid
115    }
116
117    /// Registers a plugin.
118    pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
119        let binding = self.binding.clone();
120        plugin.bind(binding);
121        let p = Arc::new(plugin) as Arc<Plugin>;
122        P::init(&self.dispatcher, p.clone());
123        if self.plugins.insert(TypeId::of::<P>(), p).is_some() {
124            panic!("registering a plugin that's already registered");
125        }
126    }
127
128    pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
129        where
130            E: Event,
131            F: Fn(&E) -> Propagation + 'static {
132        self.dispatcher.register(pri, func);
133    }
134
135    /// Returns the plugin given by the type parameter, if it exists, else panics.
136    pub fn plugin<P: Plugin>(&self) -> &P {
137        self.plugins.get(&TypeId::of::<P>())
138                    .expect("the requested plugin was not registered")
139                    .as_any()
140                    .downcast_ref::<P>()
141                    .expect("plugin downcast failure (should not happen!!)")
142    }
143
144    /// Returns the next event and flush the send queue.
145    pub fn main(&mut self) -> Result<(), Error> {
146        self.dispatcher.flush_all();
147        loop {
148            let elem = self.read_element()?;
149            self.dispatcher.dispatch(ReceiveElement(elem));
150            self.dispatcher.flush_all();
151        }
152    }
153
154    fn reset_stream(&self) {
155        self.transport.lock().unwrap().reset_stream()
156    }
157
158    fn read_element(&self) -> Result<Element, Error> {
159        self.transport.lock().unwrap().read_element()
160    }
161
162    fn write_element(&self, elem: &Element) -> Result<(), Error> {
163        self.transport.lock().unwrap().write_element(elem)
164    }
165
166    fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
167        self.transport.lock().unwrap().read_event()
168    }
169
170    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
171        let features = self.wait_for_features()?;
172        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
173        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
174        // TODO: better way for selecting these, enabling anonymous auth
175        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
176            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
177        }
178        else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
179            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
180        }
181        else if ms.contains("SCRAM-SHA-256") {
182            if credentials.channel_binding != ChannelBinding::None {
183                credentials.channel_binding = ChannelBinding::Unsupported;
184            }
185            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
186        }
187        else if ms.contains("SCRAM-SHA-1") {
188            if credentials.channel_binding != ChannelBinding::None {
189                credentials.channel_binding = ChannelBinding::Unsupported;
190            }
191            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
192        }
193        else if ms.contains("PLAIN") {
194            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
195        }
196        else {
197            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
198        };
199        let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
200        let mut elem = Element::builder("auth")
201                               .ns(ns::SASL)
202                               .attr("mechanism", mechanism.name())
203                               .build();
204        if !auth.is_empty() {
205            elem.append_text_node(base64::encode(&auth));
206        }
207        self.write_element(&elem)?;
208        loop {
209            let n = self.read_element()?;
210            if n.is("challenge", ns::SASL) {
211                let text = n.text();
212                let challenge = if text == "" {
213                    Vec::new()
214                }
215                else {
216                    base64::decode(&text)?
217                };
218                let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
219                let mut elem = Element::builder("response")
220                                       .ns(ns::SASL)
221                                       .build();
222                if !response.is_empty() {
223                    elem.append_text_node(base64::encode(&response));
224                }
225                self.write_element(&elem)?;
226            }
227            else if n.is("success", ns::SASL) {
228                let text = n.text();
229                let data = if text == "" {
230                    Vec::new()
231                }
232                else {
233                    base64::decode(&text)?
234                };
235                mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
236                self.reset_stream();
237                {
238                    let mut g = self.transport.lock().unwrap();
239                    C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
240                }
241                self.wait_for_features()?;
242                return Ok(());
243            }
244            else if n.is("failure", ns::SASL) {
245                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
246                return Err(Error::XmppSaslError(inner));
247            }
248        }
249    }
250
251    fn bind(&mut self) -> Result<(), Error> {
252        let mut elem = Element::builder("iq")
253                               .attr("id", "bind")
254                               .attr("type", "set")
255                               .build();
256        let mut bind = Element::builder("bind")
257                               .ns(ns::BIND)
258                               .build();
259        if let Some(ref resource) = self.jid.resource {
260            let res = Element::builder("resource")
261                              .ns(ns::BIND)
262                              .append(resource.to_owned())
263                              .build();
264            bind.append_child(res);
265        }
266        elem.append_child(bind);
267        self.write_element(&elem)?;
268        loop {
269            let n = self.read_element()?;
270            if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
271                return Ok(());
272            }
273        }
274    }
275
276    fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
277        // TODO: this is very ugly
278        loop {
279            let e = self.read_event()?;
280            match e {
281                ReaderEvent::StartElement { .. } => {
282                    break;
283                },
284                _ => (),
285            }
286        }
287        loop {
288            let n = self.read_element()?;
289            if n.is("features", ns::STREAM) {
290                let mut features = StreamFeatures {
291                    sasl_mechanisms: None,
292                };
293                if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
294                    let mut res = HashSet::new();
295                    for cld in ms.children() {
296                        res.insert(cld.text());
297                    }
298                    features.sasl_mechanisms = Some(res);
299                }
300                return Ok(features);
301            }
302        }
303    }
304}