1use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
2use sasl::common::{ChannelBinding, Credentials};
3use std::mem::replace;
4use std::pin::Pin;
5use std::task::Context;
6use tokio::net::TcpStream;
7use tokio::task::JoinHandle;
8#[cfg(feature = "tls-native")]
9use tokio_native_tls::TlsStream;
10#[cfg(feature = "tls-rust")]
11use tokio_rustls::client::TlsStream;
12use xmpp_parsers::{ns, Element, Jid};
13
14use super::auth::auth;
15use super::bind::bind;
16use crate::event::Event;
17use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
18use crate::starttls::starttls;
19use crate::xmpp_codec::Packet;
20use crate::xmpp_stream::{self, add_stanza_id};
21use crate::{Error, ProtocolError};
22
23/// XMPP client connection and state
24///
25/// It is able to reconnect. TODO: implement session management.
26///
27/// This implements the `futures` crate's [`Stream`](#impl-Stream) and
28/// [`Sink`](#impl-Sink<Packet>) traits.
29pub struct Client {
30 config: Config,
31 state: ClientState,
32 reconnect: bool,
33 // TODO: tls_required=true
34}
35
/// How the client locates the XMPP server it connects to.
#[derive(Clone, Debug)]
pub enum ServerConfig {
    /// Resolve the server through the `_xmpp-client._tcp` SRV record of the
    /// JID's domain.
    UseSrv,
    /// Connect to an explicitly given host and port.
    // NOTE(review): presumably silences a dead-code warning when `Manual` is
    // never constructed inside the crate; confirm it is still needed.
    #[allow(unused)]
    Manual {
        /// Host name of the server
        host: String,
        /// TCP port of the server
        port: u16,
    },
}
50
51/// XMMPP client configuration
52#[derive(Clone, Debug)]
53pub struct Config {
54 /// jid of the account
55 pub jid: Jid,
56 /// password of the account
57 pub password: String,
58 /// server configuration for the account
59 pub server: ServerConfig,
60}
61
62type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
63
64enum ClientState {
65 Invalid,
66 Disconnected,
67 Connecting(JoinHandle<Result<XMPPStream, Error>>),
68 Connected(XMPPStream),
69}
70
71impl Client {
72 /// Start a new XMPP client
73 ///
74 /// Start polling the returned instance so that it will connect
75 /// and yield events.
76 pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
77 let config = Config {
78 jid: jid.into(),
79 password: password.into(),
80 server: ServerConfig::UseSrv,
81 };
82 Self::new_with_config(config)
83 }
84
85 /// Start a new client given that the JID is already parsed.
86 pub fn new_with_config(config: Config) -> Self {
87 let connect = tokio::spawn(Self::connect(
88 config.server.clone(),
89 config.jid.clone(),
90 config.password.clone(),
91 ));
92 let client = Client {
93 config,
94 state: ClientState::Connecting(connect),
95 reconnect: false,
96 };
97 client
98 }
99
100 /// Set whether to reconnect (`true`) or let the stream end
101 /// (`false`) when a connection to the server has ended.
102 pub fn set_reconnect(&mut self, reconnect: bool) -> &mut Self {
103 self.reconnect = reconnect;
104 self
105 }
106
107 async fn connect(
108 server: ServerConfig,
109 jid: Jid,
110 password: String,
111 ) -> Result<XMPPStream, Error> {
112 let username = jid.clone().node().unwrap();
113 let password = password;
114
115 // TCP connection
116 let tcp_stream = match server {
117 ServerConfig::UseSrv => {
118 connect_with_srv(&jid.clone().domain(), "_xmpp-client._tcp", 5222).await?
119 }
120 ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
121 };
122
123 // Unencryped XMPPStream
124 let xmpp_stream =
125 xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
126 .await?;
127
128 let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
129 // TlsStream
130 let tls_stream = starttls(xmpp_stream).await?;
131 // Encrypted XMPPStream
132 xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
133 .await?
134 } else {
135 return Err(Error::Protocol(ProtocolError::NoTls));
136 };
137
138 let creds = Credentials::default()
139 .with_username(username)
140 .with_password(password)
141 .with_channel_binding(ChannelBinding::None);
142 // Authenticated (unspecified) stream
143 let stream = auth(xmpp_stream, creds).await?;
144 // Authenticated XMPPStream
145 let xmpp_stream =
146 xmpp_stream::XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
147
148 // XMPPStream bound to user session
149 let xmpp_stream = bind(xmpp_stream).await?;
150 Ok(xmpp_stream)
151 }
152
153 /// Get the client's bound JID (the one reported by the XMPP
154 /// server).
155 pub fn bound_jid(&self) -> Option<&Jid> {
156 match self.state {
157 ClientState::Connected(ref stream) => Some(&stream.jid),
158 _ => None,
159 }
160 }
161
162 /// Send stanza
163 pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
164 self.send(Packet::Stanza(add_stanza_id(stanza, ns::JABBER_CLIENT)))
165 .await
166 }
167
168 /// End connection by sending `</stream:stream>`
169 ///
170 /// You may expect the server to respond with the same. This
171 /// client will then drop its connection.
172 ///
173 /// Make sure to disable reconnect.
174 pub async fn send_end(&mut self) -> Result<(), Error> {
175 self.send(Packet::StreamEnd).await
176 }
177}
178
179/// Incoming XMPP events
180///
181/// In an `async fn` you may want to use this with `use
182/// futures::stream::StreamExt;`
183impl Stream for Client {
184 type Item = Event;
185
186 /// Low-level read on the XMPP stream, allowing the underlying
187 /// machinery to:
188 ///
189 /// * connect,
190 /// * starttls,
191 /// * authenticate,
192 /// * bind a session, and finally
193 /// * receive stanzas
194 ///
195 /// ...for your client
196 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
197 let state = replace(&mut self.state, ClientState::Invalid);
198
199 match state {
200 ClientState::Invalid => panic!("Invalid client state"),
201 ClientState::Disconnected if self.reconnect => {
202 // TODO: add timeout
203 let connect = tokio::spawn(Self::connect(
204 self.config.server.clone(),
205 self.config.jid.clone(),
206 self.config.password.clone(),
207 ));
208 self.state = ClientState::Connecting(connect);
209 self.poll_next(cx)
210 }
211 ClientState::Disconnected => Poll::Ready(None),
212 ClientState::Connecting(mut connect) => match Pin::new(&mut connect).poll(cx) {
213 Poll::Ready(Ok(Ok(stream))) => {
214 let bound_jid = stream.jid.clone();
215 self.state = ClientState::Connected(stream);
216 Poll::Ready(Some(Event::Online {
217 bound_jid,
218 resumed: false,
219 }))
220 }
221 Poll::Ready(Ok(Err(e))) => {
222 self.state = ClientState::Disconnected;
223 return Poll::Ready(Some(Event::Disconnected(e.into())));
224 }
225 Poll::Ready(Err(e)) => {
226 self.state = ClientState::Disconnected;
227 panic!("connect task: {}", e);
228 }
229 Poll::Pending => {
230 self.state = ClientState::Connecting(connect);
231 Poll::Pending
232 }
233 },
234 ClientState::Connected(mut stream) => {
235 // Poll sink
236 match Pin::new(&mut stream).poll_ready(cx) {
237 Poll::Pending => (),
238 Poll::Ready(Ok(())) => (),
239 Poll::Ready(Err(e)) => {
240 self.state = ClientState::Disconnected;
241 return Poll::Ready(Some(Event::Disconnected(e.into())));
242 }
243 };
244
245 // Poll stream
246 match Pin::new(&mut stream).poll_next(cx) {
247 Poll::Ready(None) => {
248 // EOF
249 self.state = ClientState::Disconnected;
250 Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
251 }
252 Poll::Ready(Some(Ok(Packet::Stanza(stanza)))) => {
253 // Receive stanza
254 self.state = ClientState::Connected(stream);
255 Poll::Ready(Some(Event::Stanza(stanza)))
256 }
257 Poll::Ready(Some(Ok(Packet::Text(_)))) => {
258 // Ignore text between stanzas
259 self.state = ClientState::Connected(stream);
260 Poll::Pending
261 }
262 Poll::Ready(Some(Ok(Packet::StreamStart(_)))) => {
263 // <stream:stream>
264 self.state = ClientState::Disconnected;
265 Poll::Ready(Some(Event::Disconnected(
266 ProtocolError::InvalidStreamStart.into(),
267 )))
268 }
269 Poll::Ready(Some(Ok(Packet::StreamEnd))) => {
270 // End of stream: </stream:stream>
271 self.state = ClientState::Disconnected;
272 Poll::Ready(Some(Event::Disconnected(Error::Disconnected)))
273 }
274 Poll::Pending => {
275 // Try again later
276 self.state = ClientState::Connected(stream);
277 Poll::Pending
278 }
279 Poll::Ready(Some(Err(e))) => {
280 self.state = ClientState::Disconnected;
281 Poll::Ready(Some(Event::Disconnected(e.into())))
282 }
283 }
284 }
285 }
286 }
287}
288
289/// Outgoing XMPP packets
290///
291/// See `send_stanza()` for an `async fn`
292impl Sink<Packet> for Client {
293 type Error = Error;
294
295 fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {
296 match self.state {
297 ClientState::Connected(ref mut stream) => {
298 Pin::new(stream).start_send(item).map_err(|e| e.into())
299 }
300 _ => Err(Error::InvalidState),
301 }
302 }
303
304 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
305 match self.state {
306 ClientState::Connected(ref mut stream) => {
307 Pin::new(stream).poll_ready(cx).map_err(|e| e.into())
308 }
309 _ => Poll::Pending,
310 }
311 }
312
313 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
314 match self.state {
315 ClientState::Connected(ref mut stream) => {
316 Pin::new(stream).poll_flush(cx).map_err(|e| e.into())
317 }
318 _ => Poll::Pending,
319 }
320 }
321
322 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
323 match self.state {
324 ClientState::Connected(ref mut stream) => {
325 Pin::new(stream).poll_close(cx).map_err(|e| e.into())
326 }
327 _ => Poll::Pending,
328 }
329 }
330}