1use jid::Jid;
2use transport::{Transport, SslTransport};
3use error::Error;
4use ns;
5use plugin::{Plugin, PluginInit, PluginProxyBinding, PluginContainer, PluginRef};
6use connection::{Connection, C2S};
7use sasl::client::Mechanism as SaslMechanism;
8use sasl::client::mechanisms::{Plain, Scram};
9use sasl::common::{Credentials as SaslCredentials, Identity, Secret, ChannelBinding};
10use sasl::common::scram::{Sha1, Sha256};
11use components::sasl_error::SaslError;
12use util::FromElement;
13use event::{Event, Dispatcher, Propagation, SendElement, ReceiveElement, Priority};
14
15use base64;
16
17use minidom::Element;
18
19use quick_xml::events::Event as XmlEvent;
20
21use std::sync::{Mutex, Arc};
22
23use std::collections::HashSet;
24
/// Stream features advertised by the server in its `<stream:features/>` element.
///
/// Only the SASL mechanism list is extracted at the moment.
// TODO: move this somewhere else and extend it with the other stream features.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StreamFeatures {
    /// Names of the offered SASL mechanisms, or `None` when the server sent no
    /// `<mechanisms/>` element.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
30
31/// A builder for `Client`s.
32pub struct ClientBuilder {
33 jid: Jid,
34 credentials: SaslCredentials,
35 host: Option<String>,
36 port: u16,
37}
38
39impl ClientBuilder {
40 /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
41 pub fn new(jid: Jid) -> ClientBuilder {
42 ClientBuilder {
43 jid: jid,
44 credentials: SaslCredentials::default(),
45 host: None,
46 port: 5222,
47 }
48 }
49
50 /// Sets the host to connect to.
51 pub fn host(mut self, host: String) -> ClientBuilder {
52 self.host = Some(host);
53 self
54 }
55
56 /// Sets the port to connect to.
57 pub fn port(mut self, port: u16) -> ClientBuilder {
58 self.port = port;
59 self
60 }
61
62 /// Sets the password to use.
63 pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
64 self.credentials = SaslCredentials {
65 identity: Identity::Username(self.jid.node.clone().expect("JID has no node")),
66 secret: Secret::password_plain(password),
67 channel_binding: ChannelBinding::None,
68 };
69 self
70 }
71
72 /// Connects to the server and returns a `Client` when succesful.
73 pub fn connect(self) -> Result<Client, Error> {
74 let host = &self.host.unwrap_or(self.jid.domain.clone());
75 let mut transport = SslTransport::connect(host, self.port)?;
76 C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
77 let dispatcher = Arc::new(Dispatcher::new());
78 let mut credentials = self.credentials;
79 credentials.channel_binding = transport.channel_bind();
80 let transport = Arc::new(Mutex::new(transport));
81 let plugin_container = Arc::new(PluginContainer::new());
82 let mut client = Client {
83 jid: self.jid.clone(),
84 transport: transport.clone(),
85 binding: PluginProxyBinding::new(dispatcher.clone(), plugin_container.clone(), self.jid),
86 plugin_container: plugin_container,
87 dispatcher: dispatcher,
88 };
89 client.dispatcher.register(Priority::Default, move |evt: &SendElement| {
90 let mut t = transport.lock().unwrap();
91 t.write_element(&evt.0).unwrap();
92 Propagation::Continue
93 });
94 client.connect(credentials)?;
95 client.bind()?;
96 Ok(client)
97 }
98}
99
100/// An XMPP client.
101pub struct Client {
102 jid: Jid,
103 transport: Arc<Mutex<SslTransport>>,
104 plugin_container: Arc<PluginContainer>,
105 binding: PluginProxyBinding,
106 dispatcher: Arc<Dispatcher>,
107}
108
109impl Client {
110 /// Returns a reference to the `Jid` associated with this `Client`.
111 pub fn jid(&self) -> &Jid {
112 &self.jid
113 }
114
115 /// Registers a plugin.
116 pub fn register_plugin<P: Plugin + PluginInit + 'static>(&mut self, mut plugin: P) {
117 let binding = self.binding.clone();
118 plugin.bind(binding);
119 let p = Arc::new(plugin);
120 P::init(&self.dispatcher, p.clone());
121 self.plugin_container.register(p);
122 }
123
124 pub fn register_handler<E, F>(&mut self, pri: Priority, func: F)
125 where
126 E: Event,
127 F: Fn(&E) -> Propagation + 'static {
128 self.dispatcher.register(pri, func);
129 }
130
131 /// Returns the plugin given by the type parameter, if it exists, else panics.
132 pub fn plugin<P: Plugin>(&self) -> PluginRef<P> {
133 self.plugin_container.get::<P>().unwrap()
134 }
135
136 /// Returns the next event and flush the send queue.
137 pub fn main(&mut self) -> Result<(), Error> {
138 self.dispatcher.flush_all();
139 loop {
140 let elem = self.read_element()?;
141 self.dispatcher.dispatch(ReceiveElement(elem));
142 self.dispatcher.flush_all();
143 }
144 }
145
146 fn reset_stream(&self) {
147 self.transport.lock().unwrap().reset_stream()
148 }
149
150 fn read_element(&self) -> Result<Element, Error> {
151 self.transport.lock().unwrap().read_element()
152 }
153
154 fn write_element(&self, elem: &Element) -> Result<(), Error> {
155 self.transport.lock().unwrap().write_element(elem)
156 }
157
158 fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
159 let features = self.wait_for_features()?;
160 let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
161 fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
162 // TODO: better way for selecting these, enabling anonymous auth
163 let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
164 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
165 }
166 else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
167 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
168 }
169 else if ms.contains("SCRAM-SHA-256") {
170 if credentials.channel_binding != ChannelBinding::None {
171 credentials.channel_binding = ChannelBinding::Unsupported;
172 }
173 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
174 }
175 else if ms.contains("SCRAM-SHA-1") {
176 if credentials.channel_binding != ChannelBinding::None {
177 credentials.channel_binding = ChannelBinding::Unsupported;
178 }
179 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
180 }
181 else if ms.contains("PLAIN") {
182 Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
183 }
184 else {
185 return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
186 };
187 let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
188 let mut elem = Element::builder("auth")
189 .ns(ns::SASL)
190 .attr("mechanism", mechanism.name())
191 .build();
192 if !auth.is_empty() {
193 elem.append_text_node(base64::encode(&auth));
194 }
195 self.write_element(&elem)?;
196 loop {
197 let n = self.read_element()?;
198 if n.is("challenge", ns::SASL) {
199 let text = n.text();
200 let challenge = if text == "" {
201 Vec::new()
202 }
203 else {
204 base64::decode(&text)?
205 };
206 let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
207 let mut elem = Element::builder("response")
208 .ns(ns::SASL)
209 .build();
210 if !response.is_empty() {
211 elem.append_text_node(base64::encode(&response));
212 }
213 self.write_element(&elem)?;
214 }
215 else if n.is("success", ns::SASL) {
216 let text = n.text();
217 let data = if text == "" {
218 Vec::new()
219 }
220 else {
221 base64::decode(&text)?
222 };
223 mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
224 self.reset_stream();
225 {
226 let mut g = self.transport.lock().unwrap();
227 C2S::init(&mut *g, &self.jid.domain, "after_sasl")?;
228 }
229 self.wait_for_features()?;
230 return Ok(());
231 }
232 else if n.is("failure", ns::SASL) {
233 let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
234 return Err(Error::XmppSaslError(inner));
235 }
236 }
237 }
238
239 fn bind(&mut self) -> Result<(), Error> {
240 let mut elem = Element::builder("iq")
241 .attr("id", "bind")
242 .attr("type", "set")
243 .build();
244 let mut bind = Element::builder("bind")
245 .ns(ns::BIND)
246 .build();
247 if let Some(ref resource) = self.jid.resource {
248 let res = Element::builder("resource")
249 .ns(ns::BIND)
250 .append(resource.to_owned())
251 .build();
252 bind.append_child(res);
253 }
254 elem.append_child(bind);
255 self.write_element(&elem)?;
256 loop {
257 let n = self.read_element()?;
258 if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
259 return Ok(());
260 }
261 }
262 }
263
264 fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
265 // TODO: this is very ugly
266 loop {
267 let mut transport = self.transport.lock().unwrap();
268 let e = transport.read_event();
269 match e {
270 Ok(XmlEvent::Start { .. }) => {
271 break;
272 },
273 _ => (),
274 }
275 }
276 loop {
277 let n = self.read_element()?;
278 if n.is("features", ns::STREAM) {
279 let mut features = StreamFeatures {
280 sasl_mechanisms: None,
281 };
282 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
283 let mut res = HashSet::new();
284 for cld in ms.children() {
285 res.insert(cld.text());
286 }
287 features.sasl_mechanisms = Some(res);
288 }
289 return Ok(features);
290 }
291 }
292 }
293}