1use xml;
2use jid::Jid;
3use transport::{Transport, SslTransport};
4use error::Error;
5use ns;
6use plugin::{Plugin, PluginInit, PluginProxyBinding, PluginContainer, PluginRef};
7use connection::{Connection, C2S};
8use sasl::client::Mechanism as SaslMechanism;
9use sasl::client::mechanisms::{Plain, Scram};
10use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
11use sasl::common::scram::{Sha1, Sha256};
12use components::sasl_error::SaslError;
13use util::FromElement;
14use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
15
16use base64;
17
18use minidom::Element;
19
20use xml::reader::XmlEvent as ReaderEvent;
21
22use std::sync::{Mutex, Arc};
23
24use std::collections::HashSet;
25
/// Stream features advertised by the server in `<stream:features/>`.
///
/// TODO: this struct should be moved somewhere else and cleaned up.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StreamFeatures {
    /// Text of every child of `<mechanisms/>`, or `None` when the server advertised no
    /// SASL `<mechanisms/>` element at all.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
31
32/// A builder for `Client`s.
33pub struct ClientBuilder {
34 jid: Jid,
35 credentials: SaslCredentials,
36 host: Option<String>,
37 port: u16,
38}
39
40impl ClientBuilder {
41 /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
42 pub fn new(jid: Jid) -> ClientBuilder {
43 ClientBuilder {
44 jid: jid,
45 credentials: SaslCredentials::default(),
46 host: None,
47 port: 5222,
48 }
49 }
50
51 /// Sets the host to connect to.
52 pub fn host(mut self, host: String) -> ClientBuilder {
53 self.host = Some(host);
54 self
55 }
56
57 /// Sets the port to connect to.
58 pub fn port(mut self, port: u16) -> ClientBuilder {
59 self.port = port;
60 self
61 }
62
63 /// Sets the password to use.
64 pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
65 self.credentials = SaslCredentials {
66 identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
67 secret: Secret::password_plain(password),
68 channel_binding: ChannelBinding::None,
69 };
70 self
71 }
72
73 /// Connects to the server and returns a `Client` when succesful.
74 pub fn connect(self) -> Result<Client, Error> {
75 let host = &self.host.unwrap_or(self.jid.domain.clone());
76 let mut transport = SslTransport::connect(host, self.port)?;
77 C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
78 let dispatcher = Arc::new(Dispatcher::new());
79 let mut credentials = self.credentials;
80 credentials.channel_binding = transport.channel_bind();
81 let transport = Arc::new(Mutex::new(transport));
82 let plugin_container = Arc::new(PluginContainer::new());
83 let mut client = Client {
84 jid: self.jid.clone(),
85 transport: transport.clone(),
86 binding: PluginProxyBinding::new(dispatcher.clone(), plugin_container.clone(), self.jid),
87 plugin_container: plugin_container,
88 dispatcher: dispatcher,
89 };
90 client.dispatcher.register(Priority::Default, move |evt: &SendElement| {
91 let mut t = transport.lock().unwrap();
92 t.write_element(&evt.0).unwrap();
93 Propagation::Continue
94 });
95 client.connect(credentials)?;
96 client.bind()?;
97 Ok(client)
98 }
99}
100
101/// An XMPP client.
102pub struct Client {
103 jid: Jid,
104 transport: Arc<Mutex<SslTransport>>,
105 plugin_container: Arc<PluginContainer>,
106 binding: PluginProxyBinding,
107 dispatcher: Arc<Dispatcher>,
108}
109
110impl Client {
111 /// Returns a reference to the `Jid` associated with this `Client`.
112 pub fn jid(&self) -> &Jid {
113 &self.jid
114 }
115
116 /// Registers a plugin.
117 pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
118 let binding = self.binding.clone();
119 plugin.bind(binding);
120 let p = Arc::new(plugin);
121 P::init(&self.dispatcher, p.clone());
122 self.plugin_container.register(p);
123 }
124
125 pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
126 where
127 E: Event,
128 F: Fn(&E) -> Propagation + 'static {
129 self.dispatcher.register(pri, func);
130 }
131
132 /// Returns the plugin given by the type parameter, if it exists, else panics.
133 pub fn plugin<P: Plugin>(&self) -> PluginRef<P> {
134 self.plugin_container.get::<P>().unwrap()
135 }
136
137 /// Returns the next event and flush the send queue.
138 pub fn main(&mut self) -> Result<(), Error> {
139 self.dispatcher.flush_all();
140 loop {
141 let elem = self.read_element()?;
142 self.dispatcher.dispatch(ReceiveElement(elem));
143 self.dispatcher.flush_all();
144 }
145 }
146
147 fn reset_stream(&self) {
148 self.transport.lock().unwrap().reset_stream()
149 }
150
151 fn read_element(&self) -> Result<Element, Error> {
152 self.transport.lock().unwrap().read_element()
153 }
154
155 fn write_element(&self, elem: &Element) -> Result<(), Error> {
156 self.transport.lock().unwrap().write_element(elem)
157 }
158
159 fn read_event(&self) -> Result<xml::reader::XmlEvent, Error> {
160 self.transport.lock().unwrap().read_event()
161 }
162
163 fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
164 let features = self.wait_for_features()?;
165 let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
166 fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
167 // TODO: better way for selecting these, enabling anonymous auth
168 let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
169 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
170 }
171 else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
172 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
173 }
174 else if ms.contains("SCRAM-SHA-256") {
175 if credentials.channel_binding != ChannelBinding::None {
176 credentials.channel_binding = ChannelBinding::Unsupported;
177 }
178 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
179 }
180 else if ms.contains("SCRAM-SHA-1") {
181 if credentials.channel_binding != ChannelBinding::None {
182 credentials.channel_binding = ChannelBinding::Unsupported;
183 }
184 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
185 }
186 else if ms.contains("PLAIN") {
187 Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
188 }
189 else {
190 return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
191 };
192 let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
193 let mut elem = Element::builder("auth")
194 .ns(ns::SASL)
195 .attr("mechanism", mechanism.name())
196 .build();
197 if !auth.is_empty() {
198 elem.append_text_node(base64::encode(&auth));
199 }
200 self.write_element(&elem)?;
201 loop {
202 let n = self.read_element()?;
203 if n.is("challenge", ns::SASL) {
204 let text = n.text();
205 let challenge = if text == "" {
206 Vec::new()
207 }
208 else {
209 base64::decode(&text)?
210 };
211 let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
212 let mut elem = Element::builder("response")
213 .ns(ns::SASL)
214 .build();
215 if !response.is_empty() {
216 elem.append_text_node(base64::encode(&response));
217 }
218 self.write_element(&elem)?;
219 }
220 else if n.is("success", ns::SASL) {
221 let text = n.text();
222 let data = if text == "" {
223 Vec::new()
224 }
225 else {
226 base64::decode(&text)?
227 };
228 mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
229 self.reset_stream();
230 {
231 let mut g = self.transport.lock().unwrap();
232 C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
233 }
234 self.wait_for_features()?;
235 return Ok(());
236 }
237 else if n.is("failure", ns::SASL) {
238 let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
239 return Err(Error::XmppSaslError(inner));
240 }
241 }
242 }
243
244 fn bind(&mut self) -> Result<(), Error> {
245 let mut elem = Element::builder("iq")
246 .attr("id", "bind")
247 .attr("type", "set")
248 .build();
249 let mut bind = Element::builder("bind")
250 .ns(ns::BIND)
251 .build();
252 if let Some(ref resource) = self.jid.resource {
253 let res = Element::builder("resource")
254 .ns(ns::BIND)
255 .append(resource.to_owned())
256 .build();
257 bind.append_child(res);
258 }
259 elem.append_child(bind);
260 self.write_element(&elem)?;
261 loop {
262 let n = self.read_element()?;
263 if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
264 return Ok(());
265 }
266 }
267 }
268
269 fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
270 // TODO: this is very ugly
271 loop {
272 let e = self.read_event()?;
273 match e {
274 ReaderEvent::StartElement { .. } => {
275 break;
276 },
277 _ => (),
278 }
279 }
280 loop {
281 let n = self.read_element()?;
282 if n.is("features", ns::STREAM) {
283 let mut features = StreamFeatures {
284 sasl_mechanisms: None,
285 };
286 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
287 let mut res = HashSet::new();
288 for cld in ms.children() {
289 res.insert(cld.text());
290 }
291 features.sasl_mechanisms = Some(res);
292 }
293 return Ok(features);
294 }
295 }
296 }
297}