1use jid::Jid;
2use transport::{Transport, SslTransport};
3use error::Error;
4use ns;
5use plugin::{Plugin, PluginProxyBinding};
6use event::AbstractEvent;
7use connection::{Connection, C2S};
8use sasl::{ Mechanism as SaslMechanism
9 , Credentials as SaslCredentials
10 , Secret as SaslSecret
11 , ChannelBinding
12 };
13use sasl::mechanisms::{Plain, Scram, Sha1, Sha256};
14use components::sasl_error::SaslError;
15use util::FromElement;
16
17use base64;
18
19use minidom::Element;
20
21use xml::reader::XmlEvent as ReaderEvent;
22
23use std::sync::mpsc::{Receiver, channel};
24
25use std::collections::HashSet;
26
/// Features the server advertised in its `<stream:features/>` element.
///
/// Only the SASL mechanism list is extracted at the moment; this type should eventually be
/// moved somewhere else and cleaned up.
#[derive(Debug)]
pub struct StreamFeatures {
    /// Text of each child of `<mechanisms/>` (the offered SASL mechanism names), or `None`
    /// when the features element carried no `<mechanisms/>` child at all.
    pub sasl_mechanisms: Option<HashSet<String>>,
}
32
33/// A builder for `Client`s.
34pub struct ClientBuilder {
35 jid: Jid,
36 credentials: SaslCredentials,
37 host: Option<String>,
38 port: u16,
39}
40
41impl ClientBuilder {
42 /// Creates a new builder for an XMPP client that will connect to `jid` with default parameters.
43 pub fn new(jid: Jid) -> ClientBuilder {
44 ClientBuilder {
45 jid: jid,
46 credentials: SaslCredentials::default(),
47 host: None,
48 port: 5222,
49 }
50 }
51
52 /// Sets the host to connect to.
53 pub fn host(mut self, host: String) -> ClientBuilder {
54 self.host = Some(host);
55 self
56 }
57
58 /// Sets the port to connect to.
59 pub fn port(mut self, port: u16) -> ClientBuilder {
60 self.port = port;
61 self
62 }
63
64 /// Sets the password to use.
65 pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
66 self.credentials = SaslCredentials {
67 username: Some(self.jid.node.clone().expect("JID has no node")),
68 secret: SaslSecret::Password(password.into()),
69 channel_binding: ChannelBinding::None,
70 };
71 self
72 }
73
74 /// Connects to the server and returns a `Client` when succesful.
75 pub fn connect(self) -> Result<Client, Error> {
76 let host = &self.host.unwrap_or(self.jid.domain.clone());
77 let mut transport = SslTransport::connect(host, self.port)?;
78 C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
79 let (sender_out, sender_in) = channel();
80 let (dispatcher_out, dispatcher_in) = channel();
81 let mut credentials = self.credentials;
82 credentials.channel_binding = transport.channel_bind();
83 let mut client = Client {
84 jid: self.jid,
85 transport: transport,
86 plugins: Vec::new(),
87 binding: PluginProxyBinding::new(sender_out, dispatcher_out),
88 sender_in: sender_in,
89 dispatcher_in: dispatcher_in,
90 };
91 client.connect(credentials)?;
92 client.bind()?;
93 Ok(client)
94 }
95}
96
97/// An XMPP client.
98pub struct Client {
99 jid: Jid,
100 transport: SslTransport,
101 plugins: Vec<Box<Plugin>>,
102 binding: PluginProxyBinding,
103 sender_in: Receiver<Element>,
104 dispatcher_in: Receiver<AbstractEvent>,
105}
106
107impl Client {
108 /// Returns a reference to the `Jid` associated with this `Client`.
109 pub fn jid(&self) -> &Jid {
110 &self.jid
111 }
112
113 /// Registers a plugin.
114 pub fn register_plugin<P: Plugin + 'static>(&mut self, mut plugin: P) {
115 plugin.bind(self.binding.clone());
116 self.plugins.push(Box::new(plugin));
117 }
118
119 /// Returns the plugin given by the type parameter, if it exists, else panics.
120 pub fn plugin<P: Plugin>(&self) -> &P {
121 for plugin in &self.plugins {
122 let any = plugin.as_any();
123 if let Some(ret) = any.downcast_ref::<P>() {
124 return ret;
125 }
126 }
127 panic!("plugin does not exist!");
128 }
129
130 /// Returns the next event and flush the send queue.
131 pub fn next_event(&mut self) -> Result<AbstractEvent, Error> {
132 self.flush_send_queue()?;
133 loop {
134 if let Ok(evt) = self.dispatcher_in.try_recv() {
135 return Ok(evt);
136 }
137 let elem = self.transport.read_element()?;
138 for plugin in self.plugins.iter_mut() {
139 plugin.handle(&elem);
140 // TODO: handle plugin return
141 }
142 self.flush_send_queue()?;
143 }
144 }
145
146 /// Flushes the send queue, sending all queued up stanzas.
147 pub fn flush_send_queue(&mut self) -> Result<(), Error> { // TODO: not sure how great of an
148 // idea it is to flush in this
149 // manner…
150 while let Ok(elem) = self.sender_in.try_recv() {
151 self.transport.write_element(&elem)?;
152 }
153 Ok(())
154 }
155
156 fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
157 let features = self.wait_for_features()?;
158 let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
159 fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
160 // TODO: better way for selecting these, enabling anonymous auth
161 let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") && credentials.channel_binding != ChannelBinding::None {
162 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
163 }
164 else if ms.contains("SCRAM-SHA-1-PLUS") && credentials.channel_binding != ChannelBinding::None {
165 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
166 }
167 else if ms.contains("SCRAM-SHA-256") {
168 if credentials.channel_binding != ChannelBinding::None {
169 credentials.channel_binding = ChannelBinding::Unsupported;
170 }
171 Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
172 }
173 else if ms.contains("SCRAM-SHA-1") {
174 if credentials.channel_binding != ChannelBinding::None {
175 credentials.channel_binding = ChannelBinding::Unsupported;
176 }
177 Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
178 }
179 else if ms.contains("PLAIN") {
180 Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
181 }
182 else {
183 return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
184 };
185 let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
186 let mut elem = Element::builder("auth")
187 .ns(ns::SASL)
188 .attr("mechanism", mechanism.name())
189 .build();
190 if !auth.is_empty() {
191 elem.append_text_node(base64::encode(&auth));
192 }
193 self.transport.write_element(&elem)?;
194 loop {
195 let n = self.transport.read_element()?;
196 if n.is("challenge", ns::SASL) {
197 let text = n.text();
198 let challenge = if text == "" {
199 Vec::new()
200 }
201 else {
202 base64::decode(&text)?
203 };
204 let response = mechanism.response(&challenge).map_err(|x| Error::SaslError(Some(x)))?;
205 let mut elem = Element::builder("response")
206 .ns(ns::SASL)
207 .build();
208 if !response.is_empty() {
209 elem.append_text_node(base64::encode(&response));
210 }
211 self.transport.write_element(&elem)?;
212 }
213 else if n.is("success", ns::SASL) {
214 let text = n.text();
215 let data = if text == "" {
216 Vec::new()
217 }
218 else {
219 base64::decode(&text)?
220 };
221 mechanism.success(&data).map_err(|x| Error::SaslError(Some(x)))?;
222 self.transport.reset_stream();
223 C2S::init(&mut self.transport, &self.jid.domain, "after_sasl")?;
224 self.wait_for_features()?;
225 return Ok(());
226 }
227 else if n.is("failure", ns::SASL) {
228 let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
229 return Err(Error::XmppSaslError(inner));
230 }
231 }
232 }
233
234 fn bind(&mut self) -> Result<(), Error> {
235 let mut elem = Element::builder("iq")
236 .attr("id", "bind")
237 .attr("type", "set")
238 .build();
239 let mut bind = Element::builder("bind")
240 .ns(ns::BIND)
241 .build();
242 if let Some(ref resource) = self.jid.resource {
243 let res = Element::builder("resource")
244 .ns(ns::BIND)
245 .append(resource.to_owned())
246 .build();
247 bind.append_child(res);
248 }
249 elem.append_child(bind);
250 self.transport.write_element(&elem)?;
251 loop {
252 let n = self.transport.read_element()?;
253 if n.is("iq", ns::CLIENT) && n.has_child("bind", ns::BIND) {
254 return Ok(());
255 }
256 }
257 }
258
259 fn wait_for_features(&mut self) -> Result<StreamFeatures, Error> {
260 // TODO: this is very ugly
261 loop {
262 let e = self.transport.read_event()?;
263 match e {
264 ReaderEvent::StartElement { .. } => {
265 break;
266 },
267 _ => (),
268 }
269 }
270 loop {
271 let n = self.transport.read_element()?;
272 if n.is("features", ns::STREAM) {
273 let mut features = StreamFeatures {
274 sasl_mechanisms: None,
275 };
276 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
277 let mut res = HashSet::new();
278 for cld in ms.children() {
279 res.insert(cld.text());
280 }
281 features.sasl_mechanisms = Some(res);
282 }
283 return Ok(features);
284 }
285 }
286 }
287}