//! Provides [`XMPPStream`], which encodes and decodes XMPP packets over an async byte stream.
2
3use futures::sink::Send;
4use futures::{sink::SinkExt, task::Poll, Sink, Stream};
5use rand::{thread_rng, Rng};
6use std::pin::Pin;
7use std::task::Context;
8use tokio::io::{AsyncRead, AsyncWrite};
9use tokio_util::codec::Framed;
10use xmpp_parsers::{Element, Jid};
11
12use crate::stream_features::StreamFeatures;
13use crate::stream_start;
14use crate::xmpp_codec::{Packet, XMPPCodec};
15use crate::Error;
16
17fn make_id() -> String {
18 let id: u64 = thread_rng().gen();
19 format!("{}", id)
20}
21
22pub(crate) fn add_stanza_id(mut stanza: Element, default_ns: &str) -> Element {
23 if stanza.is("iq", default_ns)
24 || stanza.is("message", default_ns)
25 || stanza.is("presence", default_ns)
26 {
27 if stanza.attr("id").is_none() {
28 stanza.set_attr("id", make_id());
29 }
30 }
31
32 stanza
33}
34
35/// Wraps a binary stream (tokio's `AsyncRead + AsyncWrite`) to decode
36/// and encode XMPP packets.
37///
38/// Implements `Sink + Stream`
39pub struct XMPPStream<S: AsyncRead + AsyncWrite + Unpin> {
40 /// The local Jabber-Id
41 pub jid: Jid,
42 /// Codec instance
43 pub stream: Framed<S, XMPPCodec>,
44 /// `<stream:features/>` for XMPP version 1.0
45 pub stream_features: StreamFeatures,
46 /// Root namespace
47 ///
48 /// This is different for either c2s, s2s, or component
49 /// connections.
50 pub ns: String,
51 /// Stream `id` attribute
52 pub id: String,
53}
54
55impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
56 /// Constructor
57 pub fn new(
58 jid: Jid,
59 stream: Framed<S, XMPPCodec>,
60 ns: String,
61 id: String,
62 stream_features: Element,
63 ) -> Self {
64 XMPPStream {
65 jid,
66 stream,
67 stream_features: StreamFeatures::new(stream_features),
68 ns,
69 id,
70 }
71 }
72
73 /// Send a `<stream:stream>` start tag
74 pub async fn start(stream: S, jid: Jid, ns: String) -> Result<Self, Error> {
75 let xmpp_stream = Framed::new(stream, XMPPCodec::new());
76 stream_start::start(xmpp_stream, jid, ns).await
77 }
78
79 /// Unwraps the inner stream
80 pub fn into_inner(self) -> S {
81 self.stream.into_inner()
82 }
83
84 /// Re-run `start()`
85 pub async fn restart(self) -> Result<Self, Error> {
86 let stream = self.stream.into_inner();
87 Self::start(stream, self.jid, self.ns).await
88 }
89}
90
91impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
92 /// Convenience method
93 pub fn send_stanza<E: Into<Element>>(&mut self, e: E) -> Send<Self, Packet> {
94 self.send(Packet::Stanza(e.into()))
95 }
96}
97
98/// Proxy to self.stream
99impl<S: AsyncRead + AsyncWrite + Unpin> Sink<Packet> for XMPPStream<S> {
100 type Error = crate::Error;
101
102 fn poll_ready(self: Pin<&mut Self>, _ctx: &mut Context) -> Poll<Result<(), Self::Error>> {
103 // Pin::new(&mut self.stream).poll_ready(ctx)
104 // .map_err(|e| e.into())
105 Poll::Ready(Ok(()))
106 }
107
108 fn start_send(
109 #[cfg_attr(rustc_least_1_46, allow(unused_mut))] mut self: Pin<&mut Self>,
110 item: Packet,
111 ) -> Result<(), Self::Error> {
112 Pin::new(&mut self.stream)
113 .start_send(item)
114 .map_err(|e| e.into())
115 }
116
117 fn poll_flush(
118 #[cfg_attr(rustc_least_1_46, allow(unused_mut))] mut self: Pin<&mut Self>,
119 cx: &mut Context,
120 ) -> Poll<Result<(), Self::Error>> {
121 Pin::new(&mut self.stream)
122 .poll_flush(cx)
123 .map_err(|e| e.into())
124 }
125
126 fn poll_close(
127 #[cfg_attr(rustc_least_1_46, allow(unused_mut))] mut self: Pin<&mut Self>,
128 cx: &mut Context,
129 ) -> Poll<Result<(), Self::Error>> {
130 Pin::new(&mut self.stream)
131 .poll_close(cx)
132 .map_err(|e| e.into())
133 }
134}
135
136/// Proxy to self.stream
137impl<S: AsyncRead + AsyncWrite + Unpin> Stream for XMPPStream<S> {
138 type Item = Result<Packet, crate::Error>;
139
140 fn poll_next(
141 #[cfg_attr(rustc_least_1_46, allow(unused_mut))] mut self: Pin<&mut Self>,
142 cx: &mut Context,
143 ) -> Poll<Option<Self::Item>> {
144 Pin::new(&mut self.stream)
145 .poll_next(cx)
146 .map(|result| result.map(|result| result.map_err(|e| e.into())))
147 }
148}