1use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
2use sasl::common::{ChannelBinding, Credentials};
3use std::mem::replace;
4use std::pin::Pin;
5use std::str::FromStr;
6use std::task::Context;
7use tokio::net::TcpStream;
8use tokio::task::JoinHandle;
9use tokio::task::LocalSet;
10#[cfg(feature = "tls-native")]
11use tokio_native_tls::TlsStream;
12#[cfg(feature = "tls-rust")]
13use tokio_rustls::client::TlsStream;
14use xmpp_parsers::{ns, Element, Jid, JidParseError};
15
16use super::auth::auth;
17use super::bind::bind;
18use crate::event::Event;
19use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
20use crate::starttls::starttls;
21use crate::xmpp_codec::Packet;
22use crate::xmpp_stream;
23use crate::{Error, ProtocolError};
24
25/// XMPP client connection and state
26///
27/// It is able to reconnect. TODO: implement session management.
28///
29/// This implements the `futures` crate's [`Stream`](#impl-Stream) and
30/// [`Sink`](#impl-Sink<Packet>) traits.
31pub struct Client {
32 config: Config,
33 state: ClientState,
34 reconnect: bool,
35 // TODO: tls_required=true
36}
37
/// XMPP server connection configuration
#[derive(Clone, Debug)]
pub enum ServerConfig {
    /// Locate the server through the `_xmpp-client._tcp` SRV record of the
    /// JID's domain (5222 is passed along as the default port).
    UseSrv,
    /// Connect directly to the given host and port, skipping the SRV lookup.
    // This variant may never be constructed inside this crate; it exists for
    // callers, so silence the dead-code lint rather than drop it.
    #[allow(unused)]
    Manual {
        /// Server host name or address.
        host: String,
        /// Server TCP port.
        port: u16,
    },
}
48
49/// XMMPP client configuration
50pub struct Config {
51 jid: Jid,
52 password: String,
53 server: ServerConfig,
54}
55
56type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
57
58enum ClientState {
59 Invalid,
60 Disconnected,
61 Connecting(JoinHandle<Result<XMPPStream, Error>>, LocalSet),
62 Connected(XMPPStream),
63}
64
65impl Client {
66 /// Start a new XMPP client
67 ///
68 /// Start polling the returned instance so that it will connect
69 /// and yield events.
70 pub fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, JidParseError> {
71 let jid = Jid::from_str(jid)?;
72 let config = Config {
73 jid: jid.clone(),
74 password: password.into(),
75 server: ServerConfig::UseSrv,
76 };
77 let client = Self::new_with_config(config);
78 Ok(client)
79 }
80
81 /// Start a new client given that the JID is already parsed.
82 pub fn new_with_config(config: Config) -> Self {
83 let local = LocalSet::new();
84 let connect = local.spawn_local(Self::connect(
85 config.server.clone(),
86 config.jid.clone(),
87 config.password.clone(),
88 ));
89 let client = Client {
90 config,
91 state: ClientState::Connecting(connect, local),
92 reconnect: false,
93 };
94 client
95 }
96
97 /// Set whether to reconnect (`true`) or let the stream end
98 /// (`false`) when a connection to the server has ended.
99 pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self {
100 self.reconnect = reconnect;
101 self
102 }
103
104 async fn connect(
105 server: ServerConfig,
106 jid: Jid,
107 password: String,
108 ) -> Result<XMPPStream, Error> {
109 let username = jid.clone().node().unwrap();
110 let password = password;
111
112 // TCP connection
113 let tcp_stream = match server {
114 ServerConfig::UseSrv => {
115 connect_with_srv(&jid.clone().domain(), "_xmpp-client._tcp", 5222).await?
116 }
117 ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
118 };
119
120 // Unencryped XMPPStream
121 let xmpp_stream =
122 xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
123 .await?;
124
125 let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
126 // TlsStream
127 let tls_stream = starttls(xmpp_stream).await?;
128 // Encrypted XMPPStream
129 xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
130 .await?
131 } else {
132 return Err(Error::Protocol(ProtocolError::NoTls));
133 };
134
135 let creds = Credentials::default()
136 .with_username(username)
137 .with_password(password)
138 .with_channel_binding(ChannelBinding::None);
139 // Authenticated (unspecified) stream
140 let stream = auth(xmpp_stream, creds).await?;
141 // Authenticated XMPPStream
142 let xmpp_stream =
143 xmpp_stream::XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
144
145 // XMPPStream bound to user session
146 let xmpp_stream = bind(xmpp_stream).await?;
147 Ok(xmpp_stream)
148 }
149
150 /// Get the client's bound JID (the one reported by the XMPP
151 /// server).
152 pub fn bound_jid(&self) -> Option<&Jid> {
153 match self.state {
154 ClientState::Connected(ref stream) => Some(&stream.jid),
155 _ => None,
156 }
157 }
158
159 /// Send stanza
160 pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
161 self.send(Packet::Stanza(stanza)).await
162 }
163
164 /// End connection by sending `</stream:stream>`
165 ///
166 /// You may expect the server to respond with the same. This
167 /// client will then drop its connection.
168 ///
169 /// Make sure to disable reconnect.
170 pub async fn send_end(&mut self) -> Result<(), Error> {
171 self.send(Packet::StreamEnd).await
172 }
173}
174
175/// Incoming XMPP events
176///
177/// In an `async fn` you may want to use this with `use
178/// futures::stream::StreamExt;`
179impl Stream for Client {
180 type Item = Event;
181
182 /// Low-level read on the XMPP stream, allowing the underlying
183 /// machinery to:
184 ///
185 /// * connect,
186 /// * starttls,
187 /// * authenticate,
188 /// * bind a session, and finally
189 /// * receive stanzas
190 ///
191 /// ...for your client
192 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
193 let state = replace(&mut self.state, ClientState::Invalid);
194
195 match state {
196 ClientState::Invalid => panic!("Invalid client state"),
197 ClientState::Disconnected if self.reconnect => {
198 // TODO: add timeout
199 let mut local = LocalSet::new();
200 let connect = local.spawn_local(Self::connect(
201 self.config.server.clone(),
202 self.config.jid.clone(),
203 self.config.password.clone(),
204 ));
205 let _ = Pin::new(&mut local).poll(cx);
206 self.state = ClientState::Connecting(connect, local);
207 self.poll_next(cx)
208 }
209 ClientState::Disconnected => Poll::Ready(None),
210 ClientState::Connecting(mut connect, mut local) => {
211 match Pin::new(&mut connect).poll(cx) {
212 Poll::Ready(Ok(Ok(stream))) => {
213 let bound_jid = stream.jid.clone();
214 self.state = ClientState::Connected(stream);
215 Poll::Ready(Some(Event::Online {
216 bound_jid,
217 resumed: false,
218 }))
219 }
220 Poll::Ready(Ok(Err(e))) => {
221 self.state = ClientState::Disconnected;
222 return Poll::Ready(Some(Event::Disconnected(e.into())));
223 }
224 Poll::Ready(Err(e)) => {
225 self.state = ClientState::Disconnected;
226 panic!("connect task: {}", e);
227 }
228 Poll::Pending => {
229 let _ = Pin::new(&mut local).poll(cx);
230
231 self.state = ClientState::Connecting(connect, local);
232 Poll::Pending
233 }
234 }
235 }
236 ClientState::Connected(mut stream) => {
237 // Poll sink
238 match Pin::new(&mut stream).poll_ready(cx) {
239 Poll::Pending => (),
240 Poll::Ready(Ok(())) => (),
241 Poll::Ready(Err(e)) => {
242 self.state = ClientState::Disconnected;
243 return Poll::Ready(Some(Event::Disconnected(e.into())));
244 }
245 };
246
247 // Poll stream
248 match Pin::new(&mut stream).poll_next(cx) {
249 Poll::Ready(None) => {
250 // EOF
251 self.state = ClientState::Disconnected;
252 Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
253 }
254 Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
255 // Receive stanza
256 self.state = ClientState::Connected(stream);
257 Poll::Ready(Some(Event::Stanza(stanza)))
258 }
259 Poll::Ready(Some(Ok(Packet::Text(_)))) => {
260 // Ignore text between stanzas
261 self.state = ClientState::Connected(stream);
262 Poll::Pending
263 }
264 Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
265 // <stream:stream>
266 self.state = ClientState::Disconnected;
267 Poll::Ready(Some(Event::Disconnected(
268 ProtocolError::InvalidStreamStart.into(),
269 )))
270 }
271 Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
272 // End of stream: </stream:stream>
273 self.state = ClientState::Disconnected;
274 Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
275 }
276 Poll::Pending => {
277 // Try again later
278 self.state = ClientState::Connected(stream);
279 Poll::Pending
280 }
281 Poll::Ready(Some(Err(e))) => {
282 self.state = ClientState::Disconnected;
283 Poll::Ready(Some(Event::Disconnected(e.into())))
284 }
285 }
286 }
287 }
288 }
289}
290
291/// Outgoing XMPP packets
292///
293/// See `send_stanza()` for an `async fn`
294impl Sink<Packet> for Client {
295 type Error = Error;
296
297 fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
298 match self.state {
299 ClientState::Connected(ref mut stream) => {
300 Pin::new(stream).start_send(item).map_err(|e| e.into())
301 }
302 _ => Err(Error::InvalidState),
303 }
304 }
305
306 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
307 match self.state {
308 ClientState::Connected(ref mut stream) => {
309 Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
310 }
311 _ => Poll::Pending,
312 }
313 }
314
315 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
316 match self.state {
317 ClientState::Connected(ref mut stream) => {
318 Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
319 }
320 _ => Poll::Pending,
321 }
322 }
323
324 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
325 match self.state {
326 ClientState::Connected(ref mut stream) => {
327 Pin::new(stream).poll_close(cx).map_err(|e| e.into())
328 }
329 _ => Poll::Pending,
330 }
331 }
332}