async_client.rs

  1use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
  2use sasl::common::{ChannelBinding, Credentials};
  3use std::mem::replace;
  4use std::pin::Pin;
  5use std::task::Context;
  6use tokio::net::TcpStream;
  7use tokio::task::JoinHandle;
  8#[cfg(feature = "tls-native")]
  9use tokio_native_tls::TlsStream;
 10#[cfg(feature = "tls-rust")]
 11use tokio_rustls::client::TlsStream;
 12use xmpp_parsers::{ns, Element, Jid};
 13
 14use super::auth::auth;
 15use super::bind::bind;
 16use crate::event::Event;
 17use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
 18use crate::starttls::starttls;
 19use crate::xmpp_codec::Packet;
 20use crate::xmpp_stream::{self, add_stanza_id};
 21use crate::{Error, ProtocolError};
 22
 23/// XMPP client connection and state
 24///
 25/// It is able to reconnect. TODO: implement session management.
 26///
 27/// This implements the `futures` crate's [`Stream`](#impl-Stream) and
 28/// [`Sink`](#impl-Sink<Packet>) traits.
 29pub struct Client {
 30    config: Config,
 31    state: ClientState,
 32    reconnect: bool,
 33    // TODO: tls_required=true
 34}
 35
 36/// XMPP server connection configuration
 37#[derive(Clone, Debug)]
 38pub enum ServerConfig {
 39    /// Use SRV record to find server host
 40    UseSrv,
 41    #[allow(unused)]
 42    /// Manually define server host and port
 43    Manual {
 44        /// Server host name
 45        host: String,
 46        /// Server port
 47        port: u16,
 48    },
 49}
 50
 51/// XMMPP client configuration
 52#[derive(Clone, Debug)]
 53pub struct Config {
 54    /// jid of the account
 55    pub jid: Jid,
 56    /// password of the account
 57    pub password: String,
 58    /// server configuration for the account
 59    pub server: ServerConfig,
 60}
 61
 62type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
 63
 64enum ClientState {
 65    Invalid,
 66    Disconnected,
 67    Connecting(JoinHandle<Result<XMPPStream, Error>>),
 68    Connected(XMPPStream),
 69}
 70
 71impl Client {
 72    /// Start a new XMPP client
 73    ///
 74    /// Start polling the returned instance so that it will connect
 75    /// and yield events.
 76    pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
 77        let config = Config {
 78            jid: jid.into(),
 79            password: password.into(),
 80            server: ServerConfig::UseSrv,
 81        };
 82        Self::new_with_config(config)
 83    }
 84
 85    /// Start a new client given that the JID is already parsed.
 86    pub fn new_with_config(config: Config) -> Self {
 87        let connect = tokio::spawn(Self::connect(
 88            config.server.clone(),
 89            config.jid.clone(),
 90            config.password.clone(),
 91        ));
 92        let client = Client {
 93            config,
 94            state: ClientState::Connecting(connect),
 95            reconnect: false,
 96        };
 97        client
 98    }
 99
100    /// Set whether to reconnect (`true`) or let the stream end
101    /// (`false`) when a connection to the server has ended.
102    pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self {
103        self.reconnect = reconnect;
104        self
105    }
106
107    async fn connect(
108        server: ServerConfig,
109        jid: Jid,
110        password: String,
111    ) -> Result<XMPPStream, Error> {
112        let username = jid.clone().node().unwrap();
113        let password = password;
114
115        // TCP connection
116        let tcp_stream = match server {
117            ServerConfig::UseSrv => {
118                connect_with_srv(&jid.clone().domain(), "_xmpp-client._tcp", 5222).await?
119            }
120            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
121        };
122
123        // Unencryped XMPPStream
124        let xmpp_stream =
125            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
126                .await?;
127
128        let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
129            // TlsStream
130            let tls_stream = starttls(xmpp_stream).await?;
131            // Encrypted XMPPStream
132            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
133                .await?
134        } else {
135            return Err(Error::Protocol(ProtocolError::NoTls));
136        };
137
138        let creds = Credentials::default()
139            .with_username(username)
140            .with_password(password)
141            .with_channel_binding(ChannelBinding::None);
142        // Authenticated (unspecified) stream
143        let stream = auth(xmpp_stream, creds).await?;
144        // Authenticated XMPPStream
145        let xmpp_stream =
146            xmpp_stream::XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
147
148        // XMPPStream bound to user session
149        let xmpp_stream = bind(xmpp_stream).await?;
150        Ok(xmpp_stream)
151    }
152
153    /// Get the client's bound JID (the one reported by the XMPP
154    /// server).
155    pub fn bound_jid(&self) -> Option<&Jid> {
156        match self.state {
157            ClientState::Connected(ref stream) => Some(&stream.jid),
158            _ => None,
159        }
160    }
161
162    /// Send stanza
163    pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
164        self.send(Packet::Stanza(add_stanza_id(stanza, ns::JABBER_CLIENT)))
165            .await
166    }
167
168    /// End connection by sending `</stream:stream>`
169    ///
170    /// You may expect the server to respond with the same. This
171    /// client will then drop its connection.
172    ///
173    /// Make sure to disable reconnect.
174    pub async fn send_end(&mut self) -> Result<(), Error> {
175        self.send(Packet::StreamEnd).await
176    }
177}
178
179/// Incoming XMPP events
180///
181/// In an `async fn` you may want to use this with `use
182/// futures::stream::StreamExt;`
183impl Stream for Client {
184    type Item = Event;
185
186    /// Low-level read on the XMPP stream, allowing the underlying
187    /// machinery to:
188    ///
189    /// * connect,
190    /// * starttls,
191    /// * authenticate,
192    /// * bind a session, and finally
193    /// * receive stanzas
194    ///
195    /// ...for your client
196    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
197        let state = replace(&mut self.state, ClientState::Invalid);
198
199        match state {
200            ClientState::Invalid => panic!("Invalid client state"),
201            ClientState::Disconnected if self.reconnect => {
202                // TODO: add timeout
203                let connect = tokio::spawn(Self::connect(
204                    self.config.server.clone(),
205                    self.config.jid.clone(),
206                    self.config.password.clone(),
207                ));
208                self.state = ClientState::Connecting(connect);
209                self.poll_next(cx)
210            }
211            ClientState::Disconnected => Poll::Ready(None),
212            ClientState::Connecting(mut connect) => match Pin::new(&mut connect).poll(cx) {
213                Poll::Ready(Ok(Ok(stream))) => {
214                    let bound_jid = stream.jid.clone();
215                    self.state = ClientState::Connected(stream);
216                    Poll::Ready(Some(Event::Online {
217                        bound_jid,
218                        resumed: false,
219                    }))
220                }
221                Poll::Ready(Ok(Err(e))) => {
222                    self.state = ClientState::Disconnected;
223                    return Poll::Ready(Some(Event::Disconnected(e.into())));
224                }
225                Poll::Ready(Err(e)) => {
226                    self.state = ClientState::Disconnected;
227                    panic!("connect task: {}", e);
228                }
229                Poll::Pending => {
230                    self.state = ClientState::Connecting(connect);
231                    Poll::Pending
232                }
233            },
234            ClientState::Connected(mut stream) => {
235                // Poll sink
236                match Pin::new(&mut stream).poll_ready(cx) {
237                    Poll::Pending => (),
238                    Poll::Ready(Ok(())) => (),
239                    Poll::Ready(Err(e)) => {
240                        self.state = ClientState::Disconnected;
241                        return Poll::Ready(Some(Event::Disconnected(e.into())));
242                    }
243                };
244
245                // Poll stream
246                match Pin::new(&mut stream).poll_next(cx) {
247                    Poll::Ready(None) => {
248                        // EOF
249                        self.state = ClientState::Disconnected;
250                        Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
251                    }
252                    Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
253                        // Receive stanza
254                        self.state = ClientState::Connected(stream);
255                        Poll::Ready(Some(Event::Stanza(stanza)))
256                    }
257                    Poll::Ready(Some(Ok(Packet::Text(_)))) => {
258                        // Ignore text between stanzas
259                        self.state = ClientState::Connected(stream);
260                        Poll::Pending
261                    }
262                    Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
263                        // <stream:stream>
264                        self.state = ClientState::Disconnected;
265                        Poll::Ready(Some(Event::Disconnected(
266                            ProtocolError::InvalidStreamStart.into(),
267                        )))
268                    }
269                    Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
270                        // End of stream: </stream:stream>
271                        self.state = ClientState::Disconnected;
272                        Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
273                    }
274                    Poll::Pending => {
275                        // Try again later
276                        self.state = ClientState::Connected(stream);
277                        Poll::Pending
278                    }
279                    Poll::Ready(Some(Err(e))) => {
280                        self.state = ClientState::Disconnected;
281                        Poll::Ready(Some(Event::Disconnected(e.into())))
282                    }
283                }
284            }
285        }
286    }
287}
288
289/// Outgoing XMPP packets
290///
291/// See `send_stanza()` for an `async fn`
292impl Sink<Packet> for Client {
293    type Error = Error;
294
295    fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
296        match self.state {
297            ClientState::Connected(ref mut stream) => {
298                Pin::new(stream).start_send(item).map_err(|e| e.into())
299            }
300            _ => Err(Error::InvalidState),
301        }
302    }
303
304    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
305        match self.state {
306            ClientState::Connected(ref mut stream) => {
307                Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
308            }
309            _ => Poll::Pending,
310        }
311    }
312
313    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
314        match self.state {
315            ClientState::Connected(ref mut stream) => {
316                Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
317            }
318            _ => Poll::Pending,
319        }
320    }
321
322    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
323        match self.state {
324            ClientState::Connected(ref mut stream) => {
325                Pin::new(stream).poll_close(cx).map_err(|e| e.into())
326            }
327            _ => Poll::Pending,
328        }
329    }
330}