async_client.rs

  1use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
  2use sasl::common::{ChannelBinding, Credentials};
  3use std::mem::replace;
  4use std::pin::Pin;
  5use std::str::FromStr;
  6use std::task::Context;
  7use tokio::net::TcpStream;
  8use tokio::task::JoinHandle;
  9use tokio::task::LocalSet;
 10#[cfg(feature = "tls-native")]
 11use tokio_native_tls::TlsStream;
 12#[cfg(feature = "tls-rust")]
 13use tokio_rustls::client::TlsStream;
 14use xmpp_parsers::{ns, Element, Jid, JidParseError};
 15
 16use super::auth::auth;
 17use super::bind::bind;
 18use crate::event::Event;
 19use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
 20use crate::starttls::starttls;
 21use crate::xmpp_codec::Packet;
 22use crate::xmpp_stream;
 23use crate::{Error, ProtocolError};
 24
 25/// XMPP client connection and state
 26///
 27/// It is able to reconnect. TODO: implement session management.
 28///
 29/// This implements the `futures` crate's [`Stream`](#impl-Stream) and
 30/// [`Sink`](#impl-Sink<Packet>) traits.
 31pub struct Client {
 32    config: Config,
 33    state: ClientState,
 34    reconnect: bool,
 35    // TODO: tls_required=true
 36}
 37
 /// XMPP server connection configuration: how the client locates the
/// server it connects to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ServerConfig {
    /// Look up the `_xmpp-client._tcp` SRV record of the JID's domain
    /// (5222 is passed as the default port).
    UseSrv,
    /// Connect directly to `host`:`port`, bypassing SRV resolution.
    // Only `UseSrv` is constructed in this file; the allow presumably
    // silences a dead-code warning for this variant — confirm before removing.
    #[allow(unused)]
    Manual {
        /// Host to connect to.
        host: String,
        /// TCP port to connect to.
        port: u16,
    },
}
 48
 49/// XMMPP client configuration
 50pub struct Config {
 51    jid: Jid,
 52    password: String,
 53    server: ServerConfig,
 54}
 55
 56type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
 57
 58enum ClientState {
 59    Invalid,
 60    Disconnected,
 61    Connecting(JoinHandle<Result<XMPPStream, Error>>, LocalSet),
 62    Connected(XMPPStream),
 63}
 64
 65impl Client {
 66    /// Start a new XMPP client
 67    ///
 68    /// Start polling the returned instance so that it will connect
 69    /// and yield events.
 70    pub fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, JidParseError> {
 71        let jid = Jid::from_str(jid)?;
 72        let config = Config {
 73            jid: jid.clone(),
 74            password: password.into(),
 75            server: ServerConfig::UseSrv,
 76        };
 77        let client = Self::new_with_config(config);
 78        Ok(client)
 79    }
 80
 81    /// Start a new client given that the JID is already parsed.
 82    pub fn new_with_config(config: Config) -> Self {
 83        let local = LocalSet::new();
 84        let connect = local.spawn_local(Self::connect(
 85            config.server.clone(),
 86            config.jid.clone(),
 87            config.password.clone(),
 88        ));
 89        let client = Client {
 90            config,
 91            state: ClientState::Connecting(connect, local),
 92            reconnect: false,
 93        };
 94        client
 95    }
 96
 97    /// Set whether to reconnect (`true`) or let the stream end
 98    /// (`false`) when a connection to the server has ended.
 99    pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self {
100        self.reconnect = reconnect;
101        self
102    }
103
104    async fn connect(
105        server: ServerConfig,
106        jid: Jid,
107        password: String,
108    ) -> Result<XMPPStream, Error> {
109        let username = jid.clone().node().unwrap();
110        let password = password;
111
112        // TCP connection
113        let tcp_stream = match server {
114            ServerConfig::UseSrv => {
115                connect_with_srv(&jid.clone().domain(), "_xmpp-client._tcp", 5222).await?
116            }
117            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
118        };
119
120        // Unencryped XMPPStream
121        let xmpp_stream =
122            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
123                .await?;
124
125        let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
126            // TlsStream
127            let tls_stream = starttls(xmpp_stream).await?;
128            // Encrypted XMPPStream
129            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
130                .await?
131        } else {
132            return Err(Error::Protocol(ProtocolError::NoTls));
133        };
134
135        let creds = Credentials::default()
136            .with_username(username)
137            .with_password(password)
138            .with_channel_binding(ChannelBinding::None);
139        // Authenticated (unspecified) stream
140        let stream = auth(xmpp_stream, creds).await?;
141        // Authenticated XMPPStream
142        let xmpp_stream =
143            xmpp_stream::XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
144
145        // XMPPStream bound to user session
146        let xmpp_stream = bind(xmpp_stream).await?;
147        Ok(xmpp_stream)
148    }
149
150    /// Get the client's bound JID (the one reported by the XMPP
151    /// server).
152    pub fn bound_jid(&self) -> Option<&Jid> {
153        match self.state {
154            ClientState::Connected(ref stream) => Some(&stream.jid),
155            _ => None,
156        }
157    }
158
159    /// Send stanza
160    pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
161        self.send(Packet::Stanza(stanza)).await
162    }
163
164    /// End connection by sending `</stream:stream>`
165    ///
166    /// You may expect the server to respond with the same. This
167    /// client will then drop its connection.
168    ///
169    /// Make sure to disable reconnect.
170    pub async fn send_end(&mut self) -> Result<(), Error> {
171        self.send(Packet::StreamEnd).await
172    }
173}
174
175/// Incoming XMPP events
176///
177/// In an `async fn` you may want to use this with `use
178/// futures::stream::StreamExt;`
179impl Stream for Client {
180    type Item = Event;
181
182    /// Low-level read on the XMPP stream, allowing the underlying
183    /// machinery to:
184    ///
185    /// * connect,
186    /// * starttls,
187    /// * authenticate,
188    /// * bind a session, and finally
189    /// * receive stanzas
190    ///
191    /// ...for your client
192    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
193        let state = replace(&mut self.state, ClientState::Invalid);
194
195        match state {
196            ClientState::Invalid => panic!("Invalid client state"),
197            ClientState::Disconnected if self.reconnect => {
198                // TODO: add timeout
199                let mut local = LocalSet::new();
200                let connect = local.spawn_local(Self::connect(
201                    self.config.server.clone(),
202                    self.config.jid.clone(),
203                    self.config.password.clone(),
204                ));
205                let _ = Pin::new(&mut local).poll(cx);
206                self.state = ClientState::Connecting(connect, local);
207                self.poll_next(cx)
208            }
209            ClientState::Disconnected => Poll::Ready(None),
210            ClientState::Connecting(mut connect, mut local) => {
211                match Pin::new(&mut connect).poll(cx) {
212                    Poll::Ready(Ok(Ok(stream))) => {
213                        let bound_jid = stream.jid.clone();
214                        self.state = ClientState::Connected(stream);
215                        Poll::Ready(Some(Event::Online {
216                            bound_jid,
217                            resumed: false,
218                        }))
219                    }
220                    Poll::Ready(Ok(Err(e))) => {
221                        self.state = ClientState::Disconnected;
222                        return Poll::Ready(Some(Event::Disconnected(e.into())));
223                    }
224                    Poll::Ready(Err(e)) => {
225                        self.state = ClientState::Disconnected;
226                        panic!("connect task: {}", e);
227                    }
228                    Poll::Pending => {
229                        let _ = Pin::new(&mut local).poll(cx);
230
231                        self.state = ClientState::Connecting(connect, local);
232                        Poll::Pending
233                    }
234                }
235            }
236            ClientState::Connected(mut stream) => {
237                // Poll sink
238                match Pin::new(&mut stream).poll_ready(cx) {
239                    Poll::Pending => (),
240                    Poll::Ready(Ok(())) => (),
241                    Poll::Ready(Err(e)) => {
242                        self.state = ClientState::Disconnected;
243                        return Poll::Ready(Some(Event::Disconnected(e.into())));
244                    }
245                };
246
247                // Poll stream
248                match Pin::new(&mut stream).poll_next(cx) {
249                    Poll::Ready(None) => {
250                        // EOF
251                        self.state = ClientState::Disconnected;
252                        Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
253                    }
254                    Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
255                        // Receive stanza
256                        self.state = ClientState::Connected(stream);
257                        Poll::Ready(Some(Event::Stanza(stanza)))
258                    }
259                    Poll::Ready(Some(Ok(Packet::Text(_)))) => {
260                        // Ignore text between stanzas
261                        self.state = ClientState::Connected(stream);
262                        Poll::Pending
263                    }
264                    Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
265                        // <stream:stream>
266                        self.state = ClientState::Disconnected;
267                        Poll::Ready(Some(Event::Disconnected(
268                            ProtocolError::InvalidStreamStart.into(),
269                        )))
270                    }
271                    Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
272                        // End of stream: </stream:stream>
273                        self.state = ClientState::Disconnected;
274                        Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
275                    }
276                    Poll::Pending => {
277                        // Try again later
278                        self.state = ClientState::Connected(stream);
279                        Poll::Pending
280                    }
281                    Poll::Ready(Some(Err(e))) => {
282                        self.state = ClientState::Disconnected;
283                        Poll::Ready(Some(Event::Disconnected(e.into())))
284                    }
285                }
286            }
287        }
288    }
289}
290
291/// Outgoing XMPP packets
292///
293/// See `send_stanza()` for an `async fn`
294impl Sink<Packet> for Client {
295    type Error = Error;
296
297    fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
298        match self.state {
299            ClientState::Connected(ref mut stream) => {
300                Pin::new(stream).start_send(item).map_err(|e| e.into())
301            }
302            _ => Err(Error::InvalidState),
303        }
304    }
305
306    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
307        match self.state {
308            ClientState::Connected(ref mut stream) => {
309                Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
310            }
311            _ => Poll::Pending,
312        }
313    }
314
315    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
316        match self.state {
317            ClientState::Connected(ref mut stream) => {
318                Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
319            }
320            _ => Poll::Pending,
321        }
322    }
323
324    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
325        match self.state {
326            ClientState::Connected(ref mut stream) => {
327                Pin::new(stream).poll_close(cx).map_err(|e| e.into())
328            }
329            _ => Poll::Pending,
330        }
331    }
332}