xmpp_codec.rs

//! XML stream codec for XMPP: decodes incoming bytes into [`Packet`]s and encodes outgoing ones.
  2
  3use crate::Error;
  4use bytes::{BufMut, BytesMut};
  5use log::debug;
  6use minidom::tree_builder::TreeBuilder;
  7use rxml::{Parse, RawParser};
  8use std::collections::HashMap;
  9use std::fmt::Write;
 10use std::io;
 11#[cfg(feature = "syntax-highlighting")]
 12use std::sync::OnceLock;
 13use tokio_util::codec::{Decoder, Encoder};
 14use xmpp_parsers::Element;
 15
/// Default syntect syntax set (newline-terminated variant), loaded by `init_syntect()`.
#[cfg(feature = "syntax-highlighting")]
static PS: OnceLock<syntect::parsing::SyntaxSet> = OnceLock::new();
/// XML syntax definition looked up in `PS` by the `xml` extension.
#[cfg(feature = "syntax-highlighting")]
static SYNTAX: OnceLock<syntect::parsing::SyntaxReference> = OnceLock::new();
/// "Solarized (dark)" theme used to colour logged XML.
#[cfg(feature = "syntax-highlighting")]
static THEME: OnceLock<syntect::highlighting::Theme> = OnceLock::new();
 22
 23#[cfg(feature = "syntax-highlighting")]
 24fn init_syntect() {
 25    let ps = syntect::parsing::SyntaxSet::load_defaults_newlines();
 26    let syntax = ps.find_syntax_by_extension("xml").unwrap();
 27    let ts = syntect::highlighting::ThemeSet::load_defaults();
 28    let theme = ts.themes["Solarized (dark)"].clone();
 29
 30    SYNTAX.set(syntax.clone()).unwrap();
 31    PS.set(ps).unwrap();
 32    THEME.set(theme).unwrap();
 33}
 34
 35#[cfg(feature = "syntax-highlighting")]
 36fn highlight_xml(xml: &str) -> String {
 37    let mut h = syntect::easy::HighlightLines::new(SYNTAX.get().unwrap(), THEME.get().unwrap());
 38    let ranges: Vec<_> = h.highlight_line(&xml, PS.get().unwrap()).unwrap();
 39    let escaped = syntect::util::as_24_bit_terminal_escaped(&ranges[..], false);
 40    format!("{}\x1b[0m", escaped)
 41}
 42
/// Fallback used when the `syntax-highlighting` feature is disabled:
/// returns the XML unchanged so it is logged as-is.
#[cfg(not(feature = "syntax-highlighting"))]
fn highlight_xml(xml: &str) -> &str {
    xml
}
 47
/// Anything that can be sent or received on an XMPP/XML stream.
///
/// `XmppCodec`'s decoder only yields `StreamStart`, `Stanza` and `StreamEnd`;
/// it never emits `Text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// `<stream:stream>` start tag.
    ///
    /// Maps attribute names to values. When decoded, the root's namespace
    /// declarations are included as `xmlns` / `xmlns:prefix` entries; when
    /// encoded, every entry is written (escaped) as an attribute.
    StreamStart(HashMap<String, String>),
    /// A complete stanza or nonza, i.e. a direct child of the stream root.
    Stanza(Element),
    /// Plain text (think whitespace keep-alive); XML-escaped when encoded.
    Text(String),
    /// `</stream:stream>` closing tag.
    StreamEnd,
}
 60
/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
pub struct XmppCodec {
    /// Outgoing: the `xmlns` value most recently written in an encoded
    /// `Packet::StreamStart`, if any.
    ns: Option<String>,
    /// Incoming: incremental XML parser fed with the received bytes.
    driver: RawParser,
    /// Incoming: assembles parser events into elements; its depth tells the
    /// decoder whether it is outside, directly inside, or nested below the
    /// stream root.
    stanza_builder: TreeBuilder,
}
 69
 70impl XmppCodec {
 71    /// Constructor
 72    pub fn new() -> Self {
 73        let stanza_builder = TreeBuilder::new();
 74        let driver = RawParser::new();
 75        #[cfg(feature = "syntax-highlighting")]
 76        if log::log_enabled!(log::Level::Debug) && PS.get().is_none() {
 77            init_syntect();
 78        }
 79        XmppCodec {
 80            ns: None,
 81            driver,
 82            stanza_builder,
 83        }
 84    }
 85}
 86
 87impl Default for XmppCodec {
 88    fn default() -> Self {
 89        Self::new()
 90    }
 91}
 92
 93impl Decoder for XmppCodec {
 94    type Item = Packet;
 95    type Error = Error;
 96
 97    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
 98        loop {
 99            let token = match self.driver.parse_buf(buf, false) {
100                Ok(Some(token)) => token,
101                Ok(None) => break,
102                Err(rxml::Error::IO(e)) if e.kind() == std::io::ErrorKind::WouldBlock => break,
103                Err(e) => return Err(minidom::Error::from(e).into()),
104            };
105
106            let had_stream_root = self.stanza_builder.depth() > 0;
107            self.stanza_builder.process_event(token)?;
108            let has_stream_root = self.stanza_builder.depth() > 0;
109
110            if !had_stream_root && has_stream_root {
111                let root = self.stanza_builder.top().unwrap();
112                let attrs =
113                    root.attrs()
114                        .map(|(name, value)| (name.to_owned(), value.to_owned()))
115                        .chain(root.prefixes.declared_prefixes().iter().map(
116                            |(prefix, namespace)| {
117                                (
118                                    prefix
119                                        .as_ref()
120                                        .map(|prefix| format!("xmlns:{}", prefix))
121                                        .unwrap_or_else(|| "xmlns".to_owned()),
122                                    namespace.clone(),
123                                )
124                            },
125                        ))
126                        .collect();
127                debug!("<< {}", highlight_xml(&String::from(root)));
128                return Ok(Some(Packet::StreamStart(attrs)));
129            } else if self.stanza_builder.depth() == 1 {
130                self.driver.release_temporaries();
131
132                if let Some(stanza) = self.stanza_builder.unshift_child() {
133                    debug!("<< {}", highlight_xml(&String::from(&stanza)));
134                    return Ok(Some(Packet::Stanza(stanza)));
135                }
136            } else if let Some(_) = self.stanza_builder.root.take() {
137                self.driver.release_temporaries();
138
139                debug!("<< {}", highlight_xml("</stream:stream>"));
140                return Ok(Some(Packet::StreamEnd));
141            }
142        }
143
144        Ok(None)
145    }
146
147    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
148        self.decode(buf)
149    }
150}
151
152impl Encoder<Packet> for XmppCodec {
153    type Error = Error;
154
155    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
156        let remaining = dst.capacity() - dst.len();
157        let max_stanza_size: usize = 2usize.pow(16);
158        if remaining < max_stanza_size {
159            dst.reserve(max_stanza_size - remaining);
160        }
161
162        fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
163            io::Error::new(io::ErrorKind::InvalidInput, e)
164        }
165
166        match item {
167            Packet::StreamStart(start_attrs) => {
168                let mut buf = String::new();
169                write!(buf, "<stream:stream").map_err(to_io_err)?;
170                for (name, value) in start_attrs {
171                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
172                    if name == "xmlns" {
173                        self.ns = Some(value);
174                    }
175                }
176                write!(buf, ">").map_err(to_io_err)?;
177
178                write!(dst, "{}", buf)?;
179                let utf8 = std::str::from_utf8(dst)?;
180                debug!(">> {}", highlight_xml(utf8))
181            }
182            Packet::Stanza(stanza) => {
183                let _ = stanza
184                    .write_to(&mut WriteBytes::new(dst))
185                    .map_err(|e| to_io_err(format!("{}", e)))?;
186                let utf8 = std::str::from_utf8(dst)?;
187                debug!(">> {}", highlight_xml(utf8));
188            }
189            Packet::Text(text) => {
190                let _ = write_text(&text, dst).map_err(to_io_err)?;
191                let utf8 = std::str::from_utf8(dst)?;
192                debug!(">> {}", highlight_xml(utf8));
193            }
194            Packet::StreamEnd => {
195                let _ = write!(dst, "</stream:stream>\n").map_err(to_io_err);
196                debug!(">> {}", highlight_xml("</stream:stream>"));
197            }
198        }
199
200        Ok(())
201    }
202}
203
204/// Write XML-escaped text string
205pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
206    write!(writer, "{}", escape(text))
207}
208
/// Escape `input` for use in XML character data or a quoted attribute value.
///
/// `&`, `<`, `>`, `'` and `"` are replaced by their predefined entities;
/// every other character is copied unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        let entity = match ch {
            '&' => Some("&amp;"),
            '<' => Some("&lt;"),
            '>' => Some("&gt;"),
            '\'' => Some("&apos;"),
            '"' => Some("&quot;"),
            _ => None,
        };
        match entity {
            Some(replacement) => escaped.push_str(replacement),
            None => escaped.push(ch),
        }
    }
    escaped
}
225
226/// BytesMut impl only std::fmt::Write but not std::io::Write. The
227/// latter trait is required for minidom's
228/// `Element::write_to_inner()`.
229struct WriteBytes<'a> {
230    dst: &'a mut BytesMut,
231}
232
233impl<'a> WriteBytes<'a> {
234    fn new(dst: &'a mut BytesMut) -> Self {
235        WriteBytes { dst }
236    }
237}
238
239impl<'a> std::io::Write for WriteBytes<'a> {
240    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
241        self.dst.put_slice(buf);
242        Ok(buf.len())
243    }
244
245    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
246        Ok(())
247    }
248}
249
250#[cfg(test)]
251mod tests {
252    use super::*;
253
254    #[test]
255    fn test_stream_start() {
256        let mut c = XmppCodec::new();
257        let mut b = BytesMut::with_capacity(1024);
258        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
259        let r = c.decode(&mut b);
260        assert!(match r {
261            Ok(Some(Packet::StreamStart(_))) => true,
262            _ => false,
263        });
264    }
265
266    #[test]
267    fn test_stream_end() {
268        let mut c = XmppCodec::new();
269        let mut b = BytesMut::with_capacity(1024);
270        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
271        let r = c.decode(&mut b);
272        assert!(match r {
273            Ok(Some(Packet::StreamStart(_))) => true,
274            _ => false,
275        });
276        b.put_slice(b"</stream:stream>");
277        let r = c.decode(&mut b);
278        assert!(match r {
279            Ok(Some(Packet::StreamEnd)) => true,
280            _ => false,
281        });
282    }
283
284    #[test]
285    fn test_truncated_stanza() {
286        let mut c = XmppCodec::new();
287        let mut b = BytesMut::with_capacity(1024);
288        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
289        let r = c.decode(&mut b);
290        assert!(match r {
291            Ok(Some(Packet::StreamStart(_))) => true,
292            _ => false,
293        });
294
295        b.put_slice("<test>ß</test".as_bytes());
296        let r = c.decode(&mut b);
297        assert!(match r {
298            Ok(None) => true,
299            _ => false,
300        });
301
302        b.put_slice(b">");
303        let r = c.decode(&mut b);
304        assert!(match r {
305            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
306            _ => false,
307        });
308    }
309
310    #[test]
311    fn test_truncated_utf8() {
312        let mut c = XmppCodec::new();
313        let mut b = BytesMut::with_capacity(1024);
314        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
315        let r = c.decode(&mut b);
316        assert!(match r {
317            Ok(Some(Packet::StreamStart(_))) => true,
318            _ => false,
319        });
320
321        b.put(&b"<test>\xc3"[..]);
322        let r = c.decode(&mut b);
323        assert!(match r {
324            Ok(None) => true,
325            _ => false,
326        });
327
328        b.put(&b"\x9f</test>"[..]);
329        let r = c.decode(&mut b);
330        assert!(match r {
331            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
332            _ => false,
333        });
334    }
335
336    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
337    #[test]
338    fn test_atrribute_prefix() {
339        let mut c = XmppCodec::new();
340        let mut b = BytesMut::with_capacity(1024);
341        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
342        let r = c.decode(&mut b);
343        assert!(match r {
344            Ok(Some(Packet::StreamStart(_))) => true,
345            _ => false,
346        });
347
348        b.put_slice(b"<status xml:lang='en'>Test status</status>");
349        let r = c.decode(&mut b);
350        assert!(match r {
351            Ok(Some(Packet::Stanza(ref el)))
352                if el.name() == "status"
353                    && el.text() == "Test status"
354                    && el.attr("xml:lang").map_or(false, |a| a == "en") =>
355                true,
356            _ => false,
357        });
358    }
359
360    /// By default, encode() only gets a BytesMut that has 8 KiB space reserved.
361    #[test]
362    fn test_large_stanza() {
363        use futures::{executor::block_on, sink::SinkExt};
364        use std::io::Cursor;
365        use tokio_util::codec::FramedWrite;
366        let mut framed = FramedWrite::new(Cursor::new(vec![]), XmppCodec::new());
367        let mut text = "".to_owned();
368        for _ in 0..2usize.pow(15) {
369            text = text + "A";
370        }
371        let stanza = Element::builder("message", "jabber:client")
372            .append(
373                Element::builder("body", "jabber:client")
374                    .append(text.as_ref())
375                    .build(),
376            )
377            .build();
378        block_on(framed.send(Packet::Stanza(stanza))).expect("send");
379        assert_eq!(
380            framed.get_ref().get_ref(),
381            &format!(
382                "<message xmlns='jabber:client'><body>{}</body></message>",
383                text
384            )
385            .as_bytes()
386        );
387    }
388
389    #[test]
390    fn test_cut_out_stanza() {
391        let mut c = XmppCodec::new();
392        let mut b = BytesMut::with_capacity(1024);
393        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
394        let r = c.decode(&mut b);
395        assert!(match r {
396            Ok(Some(Packet::StreamStart(_))) => true,
397            _ => false,
398        });
399
400        b.put_slice(b"<message ");
401        b.put_slice(b"type='chat'><body>Foo</body></message>");
402        let r = c.decode(&mut b);
403        assert!(match r {
404            Ok(Some(Packet::Stanza(_))) => true,
405            _ => false,
406        });
407    }
408}