client.rs

  1use xml;
  2use jid::Jid;
  3use transport::{Transport, SslTransport};
  4use error::Error;
  5use ns;
  6use plugin::{Plugin, PluginInit, PluginProxyBinding};
  7use connection::{Connection, C2S};
  8use sasl::client::Mechanism as SaslMechanism;
  9use sasl::client::mechanisms::{Plain, Scram};
 10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
 11use sasl::common::scram::{Sha1, Sha256};
 12use components::sasl_error::SaslError;
 13use util::FromElement;
 14use event::{Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
 15
 16use base64;
 17
 18use minidom::Element;
 19
 20use xml::reader::XmlEvent as ReaderEvent;
 21
 22use std::sync::{Mutex, Arc};
 23
 24use std::collections::{HashSet, HashMap};
 25
 26use std::any::TypeId;
 27
 28/// Struct that should be moved somewhere else and cleaned up.
 29#[derive(Debug)]
 30pub struct StreamFeatures {
 31    pub sasl_mechanisms: Option<HashSet<String>>,
 32}
 33
 34/// A builder for `Client`s.
 35pub struct ClientBuilder {
 36    jid: Jid,
 37    credentials: SaslCredentials,
 38    host: Option<String>,
 39    port: u16,
 40}
 41
 42impl ClientBuilder {
 43    /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
 44    pub fn new(jid: Jid) -> ClientBuilder {
 45        ClientBuilder {
 46            jid: jid,
 47            credentials: SaslCredentials::default(),
 48            host: None,
 49            port: 5222,
 50        }
 51    }
 52
 53    /// Sets the host to connect to.
 54    pub fn host(mut self, host: String) -> ClientBuilder {
 55        self.host = Some(host);
 56        self
 57    }
 58
 59    /// Sets the port to connect to.
 60    pub fn port(mut self, port: u16) -> ClientBuilder {
 61        self.port = port;
 62        self
 63    }
 64
 65    /// Sets the password to use.
 66    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
 67        self.credentials = SaslCredentials {
 68            identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
 69            secret: Secret::password_plain(password),
 70            channel_binding: ChannelBinding::None,
 71        };
 72        self
 73    }
 74
 75    /// Connects to the server and returns a `Client` when succesful.
 76    pub fn connect(self) -> Result<Client, Error> {
 77        let host = &self.host.unwrap_or(self.jid.domain.clone());
 78        let mut transport = SslTransport::connect(host, self.port)?;
 79        C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
 80        let dispatcher = Arc::new(Mutex::new(Dispatcher::new()));
 81        let mut credentials = self.credentials;
 82        credentials.channel_binding = transport.channel_bind();
 83        let transport = Arc::new(Mutex::new(transport));
 84        let mut client = Client {
 85            jid: self.jid,
 86            transport: transport.clone(),
 87            plugins: HashMap::new(),
 88            binding: PluginProxyBinding::new(dispatcher.clone()),
 89            dispatcher: dispatcher,
 90        };
 91        client.dispatcher.lock().unwrap().register(Priority::Default, move |evt: &SendElement| {
 92            let mut t = transport.lock().unwrap();
 93            t.write_element(&evt.0).unwrap();
 94            Propagation::Continue
 95        });
 96        client.connect(credentials)?;
 97        client.bind()?;
 98        Ok(client)
 99    }
100}
101
102/// An XMPP client.
103pub struct Client {
104    jid: Jid,
105    transport: Arc<Mutex<SslTransport>>,
106    plugins: HashMap<TypeId, Arc<Box<Plugin>>>,
107    binding: PluginProxyBinding,
108    dispatcher: Arc<Mutex<Dispatcher>>,
109}
110
111impl Client {
112    /// Returns a reference to the `Jid` associated with this `Client`.
113    pub fn jid(&self) -> &Jid {
114        &self.jid
115    }
116
117    /// Registers a plugin.
118    pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
119        let binding = self.binding.clone();
120        plugin.bind(binding);
121        let p = Arc::new(Box::new(plugin) as Box<Plugin>);
122        {
123            let mut disp = self.dispatcher.lock().unwrap();
124            P::init(&mut disp, p.clone());
125        }
126        if self.plugins.insert(TypeId::of::<P>(), p).is_some() {
127            panic!("registering a plugin that's already registered");
128        }
129    }
130
131    /// Returns the plugin given by the type parameter, if it exists, else panics.
132    pub fn plugin<P: Plugin>(&self) -> &P {
133        self.plugins.get(&TypeId::of::<P>())
134                    .expect("the requested plugin was not registered")
135                    .as_any()
136                    .downcast_ref::<P>()
137                    .expect("plugin downcast failure (should not happen!!)")
138    }
139
140    /// Returns the next event and flush the send queue.
141    pub fn main(&mut self) -> Result<(), Error> {
142        self.dispatcher.lock().unwrap().flush_all();
143        loop {
144            let elem = self.read_element()?;
145            {
146                let mut disp = self.dispatcher.lock().unwrap();
147                disp.dispatch(ReceiveElement(elem));
148                disp.flush_all();
149            }
150        }
151    }
152
153    fn reset_stream(&self) {
154        self.transport.lock().unwrap().reset_stream()
155    }
156
157    fn read_element(&self) -> Result<Element, Error> {
158        self.transport.lock().unwrap().read_element()
159    }
160
161    fn write_element(&self, elem: &Element) -> Result<(), Error> {
162        self.transport.lock().unwrap().write_element(elem)
163    }
164
165    fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
166        self.transport.lock().unwrap().read_event()
167    }
168
169    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
170        let features = self.wait_for_features()?;
171        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
172        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
173        // TODO: better way for selecting these, enabling anonymous auth
174        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
175            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
176        }
177        else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
178            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
179        }
180        else if ms.contains("SCRAM-SHA-256") {
181            if credentials.channel_binding != ChannelBinding::None {
182                credentials.channel_binding = ChannelBinding::Unsupported;
183            }
184            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
185        }
186        else if ms.contains("SCRAM-SHA-1") {
187            if credentials.channel_binding != ChannelBinding::None {
188                credentials.channel_binding = ChannelBinding::Unsupported;
189            }
190            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
191        }
192        else if ms.contains("PLAIN") {
193            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
194        }
195        else {
196            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
197        };
198        let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
199        let mut elem = Element::builder("auth")
200                               .ns(ns::SASL)
201                               .attr("mechanism", mechanism.name())
202                               .build();
203        if !auth.is_empty() {
204            elem.append_text_node(base64::encode(&auth));
205        }
206        self.write_element(&elem)?;
207        loop {
208            let n = self.read_element()?;
209            if n.is("challenge", ns::SASL) {
210                let text = n.text();
211                let challenge = if text == "" {
212                    Vec::new()
213                }
214                else {
215                    base64::decode(&text)?
216                };
217                let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
218                let mut elem = Element::builder("response")
219                                       .ns(ns::SASL)
220                                       .build();
221                if !response.is_empty() {
222                    elem.append_text_node(base64::encode(&response));
223                }
224                self.write_element(&elem)?;
225            }
226            else if n.is("success", ns::SASL) {
227                let text = n.text();
228                let data = if text == "" {
229                    Vec::new()
230                }
231                else {
232                    base64::decode(&text)?
233                };
234                mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
235                self.reset_stream();
236                {
237                    let mut g = self.transport.lock().unwrap();
238                    C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
239                }
240                self.wait_for_features()?;
241                return Ok(());
242            }
243            else if n.is("failure", ns::SASL) {
244                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
245                return Err(Error::XmppSaslError(inner));
246            }
247        }
248    }
249
250    fn bind(&mut self) -> Result<(), Error> {
251        let mut elem = Element::builder("iq")
252                               .attr("id", "bind")
253                               .attr("type", "set")
254                               .build();
255        let mut bind = Element::builder("bind")
256                               .ns(ns::BIND)
257                               .build();
258        if let Some(ref resource) = self.jid.resource {
259            let res = Element::builder("resource")
260                              .ns(ns::BIND)
261                              .append(resource.to_owned())
262                              .build();
263            bind.append_child(res);
264        }
265        elem.append_child(bind);
266        self.write_element(&elem)?;
267        loop {
268            let n = self.read_element()?;
269            if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
270                return Ok(());
271            }
272        }
273    }
274
275    fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
276        // TODO: this is very ugly
277        loop {
278            let e = self.read_event()?;
279            match e {
280                ReaderEvent::StartElement { .. } => {
281                    break;
282                },
283                _ => (),
284            }
285        }
286        loop {
287            let n = self.read_element()?;
288            if n.is("features", ns::STREAM) {
289                let mut features = StreamFeatures {
290                    sasl_mechanisms: None,
291                };
292                if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
293                    let mut res = HashSet::new();
294                    for cld in ms.children() {
295                        res.insert(cld.text());
296                    }
297                    features.sasl_mechanisms = Some(res);
298                }
299                return Ok(features);
300            }
301        }
302    }
303}