1use xml;
2use jid::Jid;
3use transport::{Transport, SslTransport};
4use error::Error;
5use ns;
6use plugin::{Plugin, PluginInit, PluginProxyBinding};
7use connection::{Connection, C2S};
8use sasl::client::Mechanism as SaslMechanism;
9use sasl::client::mechanisms::{Plain, Scram};
10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
11use sasl::common::scram::{Sha1, Sha256};
12use components::sasl_error::SaslError;
13use util::FromElement;
14use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
15
16use base64;
17
18use minidom::Element;
19
20use xml::reader::XmlEvent as ReaderEvent;
21
22use std::sync::{Mutex, Arc};
23
24use std::collections::{HashSet, HashMap};
25
26use std::any::TypeId;
27
// TODO: move this type somewhere more appropriate and clean it up.
/// Features advertised by the server in its `<stream:features/>` element.
///
/// Only the list of SASL mechanisms is extracted at the moment.
#[derive(Debug)]
pub struct StreamFeatures {
    /// Names of the SASL mechanisms offered by the server, or `None` when the
    /// features element contained no SASL `<mechanisms/>` child.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
33
34/// A builder for `Client`s.
35pub struct ClientBuilder {
36 jid: Jid,
37 credentials: SaslCredentials,
38 host: Option<String>,
39 port: u16,
40}
41
42impl ClientBuilder {
43 /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
44 pub fn new(jid: Jid) -> ClientBuilder {
45 ClientBuilder {
46 jid: jid,
47 credentials: SaslCredentials::default(),
48 host: None,
49 port: 5222,
50 }
51 }
52
53 /// Sets the host to connect to.
54 pub fn host(mut self, host: String) -> ClientBuilder {
55 self.host = Some(host);
56 self
57 }
58
59 /// Sets the port to connect to.
60 pub fn port(mut self, port: u16) -> ClientBuilder {
61 self.port = port;
62 self
63 }
64
65 /// Sets the password to use.
66 pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
67 self.credentials = SaslCredentials {
68 identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
69 secret: Secret::password_plain(password),
70 channel_binding: ChannelBinding::None,
71 };
72 self
73 }
74
75 /// Connects to the server and returns a `Client` when succesful.
76 pub fn connect(self) -> Result<Client, Error> {
77 let host = &self.host.unwrap_or(self.jid.domain.clone());
78 let mut transport = SslTransport::connect(host, self.port)?;
79 C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
80 let dispatcher = Arc::new(Dispatcher::new());
81 let mut credentials = self.credentials;
82 credentials.channel_binding = transport.channel_bind();
83 let transport = Arc::new(Mutex::new(transport));
84 let mut client = Client {
85 jid: self.jid,
86 transport: transport.clone(),
87 plugins: HashMap::new(),
88 binding: PluginProxyBinding::new(dispatcher.clone()),
89 dispatcher: dispatcher,
90 };
91 client.dispatcher.register(Priority::Default, move |evt: &SendElement| {
92 let mut t = transport.lock().unwrap();
93 t.write_element(&evt.0).unwrap();
94 Propagation::Continue
95 });
96 client.connect(credentials)?;
97 client.bind()?;
98 Ok(client)
99 }
100}
101
102/// An XMPP client.
103pub struct Client {
104 jid: Jid,
105 transport: Arc<Mutex<SslTransport>>,
106 plugins: HashMap<TypeId, Arc<Plugin>>,
107 binding: PluginProxyBinding,
108 dispatcher: Arc<Dispatcher>,
109}
110
111impl Client {
112 /// Returns a reference to the `Jid` associated with this `Client`.
113 pub fn jid(&self) -> &Jid {
114 &self.jid
115 }
116
117 /// Registers a plugin.
118 pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
119 let binding = self.binding.clone();
120 plugin.bind(binding);
121 let p = Arc::new(plugin) as Arc<Plugin>;
122 P::init(&self.dispatcher, p.clone());
123 if self.plugins.insert(TypeId::of::<P>(), p).is_some() {
124 panic!("registering a plugin that's already registered");
125 }
126 }
127
128 pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
129 where
130 E: Event,
131 F: Fn(&E) -> Propagation + 'static {
132 self.dispatcher.register(pri, func);
133 }
134
135 /// Returns the plugin given by the type parameter, if it exists, else panics.
136 pub fn plugin<P: Plugin>(&self) -> &P {
137 self.plugins.get(&TypeId::of::<P>())
138 .expect("the requested plugin was not registered")
139 .as_any()
140 .downcast_ref::<P>()
141 .expect("plugin downcast failure (should not happen!!)")
142 }
143
144 /// Returns the next event and flush the send queue.
145 pub fn main(&mut self) -> Result<(), Error> {
146 self.dispatcher.flush_all();
147 loop {
148 let elem = self.read_element()?;
149 self.dispatcher.dispatch(ReceiveElement(elem));
150 self.dispatcher.flush_all();
151 }
152 }
153
154 fn reset_stream(&self) {
155 self.transport.lock().unwrap().reset_stream()
156 }
157
158 fn read_element(&self) -> Result<Element, Error> {
159 self.transport.lock().unwrap().read_element()
160 }
161
162 fn write_element(&self, elem: &Element) -> Result<(), Error> {
163 self.transport.lock().unwrap().write_element(elem)
164 }
165
166 fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
167 self.transport.lock().unwrap().read_event()
168 }
169
170 fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
171 let features = self.wait_for_features()?;
172 let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
173 fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
174 // TODO: better way for selecting these, enabling anonymous auth
175 let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
176 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
177 }
178 else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
179 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
180 }
181 else if ms.contains("SCRAM-SHA-256") {
182 if credentials.channel_binding != ChannelBinding::None {
183 credentials.channel_binding = ChannelBinding::Unsupported;
184 }
185 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
186 }
187 else if ms.contains("SCRAM-SHA-1") {
188 if credentials.channel_binding != ChannelBinding::None {
189 credentials.channel_binding = ChannelBinding::Unsupported;
190 }
191 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
192 }
193 else if ms.contains("PLAIN") {
194 Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
195 }
196 else {
197 return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
198 };
199 let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
200 let mut elem = Element::builder("auth")
201 .ns(ns::SASL)
202 .attr("mechanism", mechanism.name())
203 .build();
204 if !auth.is_empty() {
205 elem.append_text_node(base64::encode(&auth));
206 }
207 self.write_element(&elem)?;
208 loop {
209 let n = self.read_element()?;
210 if n.is("challenge", ns::SASL) {
211 let text = n.text();
212 let challenge = if text == "" {
213 Vec::new()
214 }
215 else {
216 base64::decode(&text)?
217 };
218 let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
219 let mut elem = Element::builder("response")
220 .ns(ns::SASL)
221 .build();
222 if !response.is_empty() {
223 elem.append_text_node(base64::encode(&response));
224 }
225 self.write_element(&elem)?;
226 }
227 else if n.is("success", ns::SASL) {
228 let text = n.text();
229 let data = if text == "" {
230 Vec::new()
231 }
232 else {
233 base64::decode(&text)?
234 };
235 mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
236 self.reset_stream();
237 {
238 let mut g = self.transport.lock().unwrap();
239 C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
240 }
241 self.wait_for_features()?;
242 return Ok(());
243 }
244 else if n.is("failure", ns::SASL) {
245 let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
246 return Err(Error::XmppSaslError(inner));
247 }
248 }
249 }
250
251 fn bind(&mut self) -> Result<(), Error> {
252 let mut elem = Element::builder("iq")
253 .attr("id", "bind")
254 .attr("type", "set")
255 .build();
256 let mut bind = Element::builder("bind")
257 .ns(ns::BIND)
258 .build();
259 if let Some(ref resource) = self.jid.resource {
260 let res = Element::builder("resource")
261 .ns(ns::BIND)
262 .append(resource.to_owned())
263 .build();
264 bind.append_child(res);
265 }
266 elem.append_child(bind);
267 self.write_element(&elem)?;
268 loop {
269 let n = self.read_element()?;
270 if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
271 return Ok(());
272 }
273 }
274 }
275
276 fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
277 // TODO: this is very ugly
278 loop {
279 let e = self.read_event()?;
280 match e {
281 ReaderEvent::StartElement { .. } => {
282 break;
283 },
284 _ => (),
285 }
286 }
287 loop {
288 let n = self.read_element()?;
289 if n.is("features", ns::STREAM) {
290 let mut features = StreamFeatures {
291 sasl_mechanisms: None,
292 };
293 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
294 let mut res = HashSet::new();
295 for cld in ms.children() {
296 res.insert(cld.text());
297 }
298 features.sasl_mechanisms = Some(res);
299 }
300 return Ok(features);
301 }
302 }
303 }
304}