xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use crate::Error;
  4use bytes::{BufMut, BytesMut};
  5use log::debug;
  6use minidom::tree_builder::TreeBuilder;
  7use rxml::{Lexer, PushDriver, RawParser};
  8use std;
  9use std::collections::HashMap;
 10use std::default::Default;
 11use std::fmt::Write;
 12use std::io;
 13use tokio_util::codec::{Decoder, Encoder};
 14use xmpp_parsers::Element;
 15
 16/// Anything that can be sent or received on an XMPP/XML stream
 17#[derive(Debug, Clone, PartialEq, Eq)]
 18pub enum Packet {
 19    /// `<stream:stream>` start tag
 20    StreamStart(HashMap<String, String>),
 21    /// A complete stanza or nonza
 22    Stanza(Element),
 23    /// Plain text (think whitespace keep-alive)
 24    Text(String),
 25    /// `</stream:stream>` closing tag
 26    StreamEnd,
 27}
 28
 29/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
 30pub struct XMPPCodec {
 31    /// Outgoing
 32    ns: Option<String>,
 33    /// Incoming
 34    driver: PushDriver<RawParser>,
 35    stanza_builder: TreeBuilder,
 36}
 37
 38impl XMPPCodec {
 39    /// Constructor
 40    pub fn new() -> Self {
 41        let stanza_builder = TreeBuilder::new();
 42        let driver = PushDriver::wrap(Lexer::new(), RawParser::new());
 43        XMPPCodec {
 44            ns: None,
 45            driver,
 46            stanza_builder,
 47        }
 48    }
 49}
 50
 51impl Default for XMPPCodec {
 52    fn default() -> Self {
 53        Self::new()
 54    }
 55}
 56
 57impl Decoder for XMPPCodec {
 58    type Item = Packet;
 59    type Error = Error;
 60
 61    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
 62        loop {
 63            let token = match self.driver.parse(buf, false) {
 64                Ok(Some(token)) => token,
 65                Ok(None) => break,
 66                Err(rxml::Error::IO(e)) if e.kind() == std::io::ErrorKind::WouldBlock => break,
 67                Err(e) => return Err(minidom::Error::from(e).into()),
 68            };
 69
 70            let had_stream_root = self.stanza_builder.depth() > 0;
 71            self.stanza_builder.process_event(token)?;
 72            let has_stream_root = self.stanza_builder.depth() > 0;
 73
 74            if !had_stream_root && has_stream_root {
 75                let root = self.stanza_builder.top().unwrap();
 76                let attrs =
 77                    root.attrs()
 78                        .map(|(name, value)| (name.to_owned(), value.to_owned()))
 79                        .chain(root.prefixes.declared_prefixes().iter().map(
 80                            |(prefix, namespace)| {
 81                                (
 82                                    prefix
 83                                        .as_ref()
 84                                        .map(|prefix| format!("xmlns:{}", prefix))
 85                                        .unwrap_or_else(|| "xmlns".to_owned()),
 86                                    namespace.clone(),
 87                                )
 88                            },
 89                        ))
 90                        .collect();
 91                debug!("<< {}", String::from(root));
 92                return Ok(Some(Packet::StreamStart(attrs)));
 93            } else if self.stanza_builder.depth() == 1 {
 94                self.driver.release_temporaries();
 95
 96                if let Some(stanza) = self.stanza_builder.unshift_child() {
 97                    debug!("<< {}", String::from(&stanza));
 98                    return Ok(Some(Packet::Stanza(stanza)));
 99                }
100            } else if let Some(_) = self.stanza_builder.root.take() {
101                self.driver.release_temporaries();
102
103                debug!("<< </stream:stream>");
104                return Ok(Some(Packet::StreamEnd));
105            }
106        }
107
108        Ok(None)
109    }
110
111    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
112        self.decode(buf)
113    }
114}
115
116impl Encoder<Packet> for XMPPCodec {
117    type Error = Error;
118
119    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
120        let remaining = dst.capacity() - dst.len();
121        let max_stanza_size: usize = 2usize.pow(16);
122        if remaining < max_stanza_size {
123            dst.reserve(max_stanza_size - remaining);
124        }
125
126        fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
127            io::Error::new(io::ErrorKind::InvalidInput, e)
128        }
129
130        match item {
131            Packet::StreamStart(start_attrs) => {
132                let mut buf = String::new();
133                write!(buf, "<stream:stream").map_err(to_io_err)?;
134                for (name, value) in start_attrs {
135                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
136                    if name == "xmlns" {
137                        self.ns = Some(value);
138                    }
139                }
140                write!(buf, ">\n").map_err(to_io_err)?;
141
142                let utf8 = std::str::from_utf8(dst)?;
143                debug!(">> {}", utf8);
144                write!(dst, "{}", buf)?
145            }
146            Packet::Stanza(stanza) => {
147                let _ = stanza
148                    .write_to(&mut WriteBytes::new(dst))
149                    .map_err(|e| to_io_err(format!("{}", e)))?;
150                let utf8 = std::str::from_utf8(dst)?;
151                debug!(">> {}", utf8);
152            }
153            Packet::Text(text) => {
154                let _ = write_text(&text, dst).map_err(to_io_err)?;
155                let utf8 = std::str::from_utf8(dst)?;
156                debug!(">> {}", utf8);
157            }
158            Packet::StreamEnd => {
159                let _ = write!(dst, "</stream:stream>\n").map_err(to_io_err);
160            }
161        }
162
163        Ok(())
164    }
165}
166
167/// Write XML-escaped text string
168pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
169    write!(writer, "{}", escape(text))
170}
171
/// Returns `input` with `&`, `<`, `>`, `'` and `"` replaced by their
/// predefined XML entities; every other character is copied unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());

    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            other => {
                escaped.push(other);
                continue;
            }
        };
        escaped.push_str(entity);
    }

    escaped
}
188
189/// BytesMut impl only std::fmt::Write but not std::io::Write. The
190/// latter trait is required for minidom's
191/// `Element::write_to_inner()`.
192struct WriteBytes<'a> {
193    dst: &'a mut BytesMut,
194}
195
196impl<'a> WriteBytes<'a> {
197    fn new(dst: &'a mut BytesMut) -> Self {
198        WriteBytes { dst }
199    }
200}
201
202impl<'a> std::io::Write for WriteBytes<'a> {
203    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
204        self.dst.put_slice(buf);
205        Ok(buf.len())
206    }
207
208    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
209        Ok(())
210    }
211}
212
213#[cfg(test)]
214mod tests {
215    use super::*;
216    use bytes::BytesMut;
217
218    #[test]
219    fn test_stream_start() {
220        let mut c = XMPPCodec::new();
221        let mut b = BytesMut::with_capacity(1024);
222        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
223        let r = c.decode(&mut b);
224        assert!(match r {
225            Ok(Some(Packet::StreamStart(_))) => true,
226            _ => false,
227        });
228    }
229
230    #[test]
231    fn test_stream_end() {
232        let mut c = XMPPCodec::new();
233        let mut b = BytesMut::with_capacity(1024);
234        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
235        let r = c.decode(&mut b);
236        assert!(match r {
237            Ok(Some(Packet::StreamStart(_))) => true,
238            _ => false,
239        });
240        b.put_slice(b"</stream:stream>");
241        let r = c.decode(&mut b);
242        assert!(match r {
243            Ok(Some(Packet::StreamEnd)) => true,
244            _ => false,
245        });
246    }
247
248    #[test]
249    fn test_truncated_stanza() {
250        let mut c = XMPPCodec::new();
251        let mut b = BytesMut::with_capacity(1024);
252        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
253        let r = c.decode(&mut b);
254        assert!(match r {
255            Ok(Some(Packet::StreamStart(_))) => true,
256            _ => false,
257        });
258
259        b.put_slice("<test>ß</test".as_bytes());
260        let r = c.decode(&mut b);
261        assert!(match r {
262            Ok(None) => true,
263            _ => false,
264        });
265
266        b.put_slice(b">");
267        let r = c.decode(&mut b);
268        assert!(match r {
269            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
270            _ => false,
271        });
272    }
273
274    #[test]
275    fn test_truncated_utf8() {
276        let mut c = XMPPCodec::new();
277        let mut b = BytesMut::with_capacity(1024);
278        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
279        let r = c.decode(&mut b);
280        assert!(match r {
281            Ok(Some(Packet::StreamStart(_))) => true,
282            _ => false,
283        });
284
285        b.put(&b"<test>\xc3"[..]);
286        let r = c.decode(&mut b);
287        assert!(match r {
288            Ok(None) => true,
289            _ => false,
290        });
291
292        b.put(&b"\x9f</test>"[..]);
293        let r = c.decode(&mut b);
294        assert!(match r {
295            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
296            _ => false,
297        });
298    }
299
300    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
301    #[test]
302    fn test_atrribute_prefix() {
303        let mut c = XMPPCodec::new();
304        let mut b = BytesMut::with_capacity(1024);
305        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
306        let r = c.decode(&mut b);
307        assert!(match r {
308            Ok(Some(Packet::StreamStart(_))) => true,
309            _ => false,
310        });
311
312        b.put_slice(b"<status xml:lang='en'>Test status</status>");
313        let r = c.decode(&mut b);
314        assert!(match r {
315            Ok(Some(Packet::Stanza(ref el)))
316                if el.name() == "status"
317                    && el.text() == "Test status"
318                    && el.attr("xml:lang").map_or(false, |a| a == "en") =>
319                true,
320            _ => false,
321        });
322    }
323
324    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
325    #[test]
326    fn test_large_stanza() {
327        use futures::{executor::block_on, sink::SinkExt};
328        use std::io::Cursor;
329        use tokio_util::codec::FramedWrite;
330        let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
331        let mut text = "".to_owned();
332        for _ in 0..2usize.pow(15) {
333            text = text + "A";
334        }
335        let stanza = Element::builder("message", "jabber:client")
336            .append(
337                Element::builder("body", "jabber:client")
338                    .append(text.as_ref())
339                    .build(),
340            )
341            .build();
342        block_on(framed.send(Packet::Stanza(stanza))).expect("send");
343        assert_eq!(
344            framed.get_ref().get_ref(),
345            &format!(
346                "<message xmlns='jabber:client'><body>{}</body></message>",
347                text
348            )
349            .as_bytes()
350        );
351    }
352
353    #[test]
354    fn test_cut_out_stanza() {
355        let mut c = XMPPCodec::new();
356        let mut b = BytesMut::with_capacity(1024);
357        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
358        let r = c.decode(&mut b);
359        assert!(match r {
360            Ok(Some(Packet::StreamStart(_))) => true,
361            _ => false,
362        });
363
364        b.put_slice(b"<message ");
365        b.put_slice(b"type='chat'><body>Foo</body></message>");
366        let r = c.decode(&mut b);
367        assert!(match r {
368            Ok(Some(Packet::Stanza(_))) => true,
369            _ => false,
370        });
371    }
372}