xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use crate::{ParseError, ParserError};
  4use bytes::{BufMut, BytesMut};
  5use std;
  6use std::borrow::Cow;
  7use std::cell::RefCell;
  8use std::collections::vec_deque::VecDeque;
  9use std::collections::HashMap;
 10use std::default::Default;
 11use std::fmt::Write;
 12use std::io;
 13use std::iter::FromIterator;
 14use std::rc::Rc;
 15use std::str::from_utf8;
 16use tokio_codec::{Decoder, Encoder};
 17use xml5ever::buffer_queue::BufferQueue;
 18use xml5ever::interface::Attribute;
 19use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
 20use xmpp_parsers::Element;
 21
/// Anything that can be sent or received on an XMPP/XML stream
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// `<stream:stream>` start tag, carrying the attributes of that tag.
    ///
    /// When produced by the decoder the keys are the attributes' local
    /// names, i.e. without any namespace prefix.
    StreamStart(HashMap<String, String>),
    /// A complete stanza or nonza (a direct child of the stream element)
    Stanza(Element),
    /// Plain text (think whitespace keep-alive); the decoder emits this for
    /// character data that is not inside a stanza.
    Text(String),
    /// `</stream:stream>` closing tag
    StreamEnd,
}
 34
/// One entry of the queue shared between `ParserSink` and `XMPPCodec`:
/// either a decoded packet or the error raised while parsing.
type QueueItem = Result<Packet, ParserError>;
 36
/// Parser state
struct ParserSink {
    // Ready stanzas, shared with XMPPCodec
    queue: Rc<RefCell<VecDeque<QueueItem>>>,
    // Parsing stack: the elements that are currently open, outermost first
    // (index 0 is the stream element, index 1 the stanza being built).
    stack: Vec<Element>,
    // One map per open element holding the namespaces declared on that
    // element, keyed by prefix (`None` is the default namespace).
    ns_stack: Vec<HashMap<Option<String>, String>>,
}
 45
 46impl ParserSink {
 47    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
 48        ParserSink {
 49            queue,
 50            stack: vec![],
 51            ns_stack: vec![],
 52        }
 53    }
 54
 55    fn push_queue(&self, pkt: Packet) {
 56        self.queue.borrow_mut().push_back(Ok(pkt));
 57    }
 58
 59    fn push_queue_error(&self, e: ParserError) {
 60        self.queue.borrow_mut().push_back(Err(e));
 61    }
 62
 63    /// Lookup XML namespace declaration for given prefix (or no prefix)
 64    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 65        for nss in self.ns_stack.iter().rev() {
 66            if let Some(ns) = nss.get(prefix) {
 67                return Some(ns);
 68            }
 69        }
 70
 71        None
 72    }
 73
 74    fn handle_start_tag(&mut self, tag: Tag) {
 75        let mut nss = HashMap::new();
 76        let is_prefix_xmlns = |attr: &Attribute| {
 77            attr.name
 78                .prefix
 79                .as_ref()
 80                .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 81                .unwrap_or(false)
 82        };
 83        for attr in &tag.attrs {
 84            match attr.name.local.as_ref() {
 85                "xmlns" => {
 86                    nss.insert(None, attr.value.as_ref().to_owned());
 87                }
 88                prefix if is_prefix_xmlns(attr) => {
 89                    nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 90                }
 91                _ => (),
 92            }
 93        }
 94        self.ns_stack.push(nss);
 95
 96        let el = {
 97            let mut el_builder = Element::builder(tag.name.local.as_ref());
 98            if let Some(el_ns) =
 99                self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
100            {
101                el_builder = el_builder.ns(el_ns);
102            }
103            for attr in &tag.attrs {
104                match attr.name.local.as_ref() {
105                    "xmlns" => (),
106                    _ if is_prefix_xmlns(attr) => (),
107                    _ => {
108                        let attr_name = if let Some(ref prefix) = attr.name.prefix {
109                            Cow::Owned(format!("{}:{}", prefix, attr.name.local))
110                        } else {
111                            Cow::Borrowed(attr.name.local.as_ref())
112                        };
113                        el_builder = el_builder.attr(attr_name, attr.value.as_ref());
114                    }
115                }
116            }
117            el_builder.build()
118        };
119
120        if self.stack.is_empty() {
121            let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
122                (
123                    attr.name.local.as_ref().to_owned(),
124                    attr.value.as_ref().to_owned(),
125                )
126            }));
127            self.push_queue(Packet::StreamStart(attrs));
128        }
129
130        self.stack.push(el);
131    }
132
133    fn handle_end_tag(&mut self) {
134        let el = self.stack.pop().unwrap();
135        self.ns_stack.pop();
136
137        match self.stack.len() {
138            // </stream:stream>
139            0 => self.push_queue(Packet::StreamEnd),
140            // </stanza>
141            1 => self.push_queue(Packet::Stanza(el)),
142            len => {
143                let parent = &mut self.stack[len - 1];
144                parent.append_child(el);
145            }
146        }
147    }
148}
149
150impl TokenSink for ParserSink {
151    fn process_token(&mut self, token: Token) {
152        match token {
153            Token::TagToken(tag) => match tag.kind {
154                TagKind::StartTag => self.handle_start_tag(tag),
155                TagKind::EndTag => self.handle_end_tag(),
156                TagKind::EmptyTag => {
157                    self.handle_start_tag(tag);
158                    self.handle_end_tag();
159                }
160                TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
161            },
162            Token::CharacterTokens(tendril) => match self.stack.len() {
163                0 | 1 => self.push_queue(Packet::Text(tendril.into())),
164                len => {
165                    let el = &mut self.stack[len - 1];
166                    el.append_text_node(tendril);
167                }
168            },
169            Token::EOFToken => self.push_queue(Packet::StreamEnd),
170            Token::ParseError(s) => {
171                // println!("ParseError: {:?}", s);
172                self.push_queue_error(ParserError::Parse(ParseError(s)));
173            }
174            _ => (),
175        }
176    }
177
178    // fn end(&mut self) {
179    // }
180}
181
/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
pub struct XMPPCodec {
    /// Outgoing: value of the `xmlns` attribute of the last encoded
    /// `Packet::StreamStart`, if any
    ns: Option<String>,
    /// Incoming: tokenizer feeding `ParserSink`
    parser: XmlTokenizer<ParserSink>,
    /// For handling incoming truncated utf8: bytes of an incomplete trailing
    /// sequence are kept here until the next `decode()` call
    // TODO: optimize using  tendrils?
    buf: Vec<u8>,
    /// Shared with ParserSink
    queue: Rc<RefCell<VecDeque<QueueItem>>>,
}
194
195impl XMPPCodec {
196    /// Constructor
197    pub fn new() -> Self {
198        let queue = Rc::new(RefCell::new(VecDeque::new()));
199        let sink = ParserSink::new(queue.clone());
200        // TODO: configure parser?
201        let parser = XmlTokenizer::new(sink, Default::default());
202        XMPPCodec {
203            ns: None,
204            parser,
205            queue,
206            buf: vec![],
207        }
208    }
209}
210
211impl Default for XMPPCodec {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl Decoder for XMPPCodec {
218    type Item = Packet;
219    type Error = ParserError;
220
221    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
222        let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
223            let mut prefix = std::mem::replace(&mut self.buf, vec![]);
224            prefix.extend_from_slice(buf.take().as_ref());
225            Box::new(prefix)
226        } else {
227            Box::new(buf.take())
228        };
229        let buf1 = buf1.as_ref().as_ref();
230        match from_utf8(buf1) {
231            Ok(mut s) => {
232                s = s.trim();
233                if !s.is_empty() {
234                    // println!("<< {}", s);
235                    let mut buffer_queue = BufferQueue::new();
236                    let tendril = FromIterator::from_iter(s.chars());
237                    buffer_queue.push_back(tendril);
238                    self.parser.feed(&mut buffer_queue);
239                }
240            }
241            // Remedies for truncated utf8
242            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
243                // Prepare all the valid data
244                let mut b = BytesMut::with_capacity(e.valid_up_to());
245                b.put(&buf1[0..e.valid_up_to()]);
246
247                // Retry
248                let result = self.decode(&mut b);
249
250                // Keep the tail back in
251                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
252
253                return result;
254            }
255            Err(e) => {
256                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
257                return Err(ParserError::Utf8(e));
258            }
259        }
260
261        match self.queue.borrow_mut().pop_front() {
262            None => Ok(None),
263            Some(result) => result.map(|pkt| Some(pkt)),
264        }
265    }
266
267    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
268        self.decode(buf)
269    }
270}
271
272impl Encoder for XMPPCodec {
273    type Item = Packet;
274    type Error = io::Error;
275
276    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
277        let remaining = dst.capacity() - dst.len();
278        let max_stanza_size: usize = 2usize.pow(16);
279        if remaining < max_stanza_size {
280            dst.reserve(max_stanza_size - remaining);
281        }
282
283        fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
284            io::Error::new(io::ErrorKind::InvalidInput, e)
285        }
286
287        match item {
288            Packet::StreamStart(start_attrs) => {
289                let mut buf = String::new();
290                write!(buf, "<stream:stream").map_err(to_io_err)?;
291                for (name, value) in start_attrs {
292                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
293                    if name == "xmlns" {
294                        self.ns = Some(value);
295                    }
296                }
297                write!(buf, ">\n").map_err(to_io_err)?;
298
299                // print!(">> {}", buf);
300                write!(dst, "{}", buf).map_err(to_io_err)
301            }
302            Packet::Stanza(stanza) => {
303                stanza
304                    .write_to(&mut WriteBytes::new(dst))
305                    .and_then(|_| {
306                        // println!(">> {:?}", dst);
307                        Ok(())
308                    })
309                    .map_err(|e| to_io_err(format!("{}", e)))
310            }
311            Packet::Text(text) => {
312                write_text(&text, dst)
313                    .and_then(|_| {
314                        // println!(">> {:?}", dst);
315                        Ok(())
316                    })
317                    .map_err(to_io_err)
318            }
319            Packet::StreamEnd => write!(dst, "</stream:stream>\n").map_err(to_io_err),
320        }
321    }
322}
323
324/// Write XML-escaped text string
325pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
326    write!(writer, "{}", escape(text))
327}
328
/// Copied from `RustyXML` for now
///
/// Returns a copy of `input` in which the five XML special characters
/// `&`, `<`, `>`, `'` and `"` are replaced by their predefined entities.
pub fn escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, ch| {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '\'' => escaped.push_str("&apos;"),
                '"' => escaped.push_str("&quot;"),
                other => escaped.push(other),
            }
            escaped
        })
}
345
346/// BytesMut impl only std::fmt::Write but not std::io::Write. The
347/// latter trait is required for minidom's
348/// `Element::write_to_inner()`.
349struct WriteBytes<'a> {
350    dst: &'a mut BytesMut,
351}
352
353impl<'a> WriteBytes<'a> {
354    fn new(dst: &'a mut BytesMut) -> Self {
355        WriteBytes { dst }
356    }
357}
358
359impl<'a> std::io::Write for WriteBytes<'a> {
360    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
361        self.dst.put_slice(buf);
362        Ok(buf.len())
363    }
364
365    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
366        Ok(())
367    }
368}
369
370#[cfg(test)]
371mod tests {
372    use super::*;
373    use bytes::BytesMut;
374
375    #[test]
376    fn test_stream_start() {
377        let mut c = XMPPCodec::new();
378        let mut b = BytesMut::with_capacity(1024);
379        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
380        let r = c.decode(&mut b);
381        assert!(match r {
382            Ok(Some(Packet::StreamStart(_))) => true,
383            _ => false,
384        });
385    }
386
387    #[test]
388    fn test_stream_end() {
389        let mut c = XMPPCodec::new();
390        let mut b = BytesMut::with_capacity(1024);
391        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
392        let r = c.decode(&mut b);
393        assert!(match r {
394            Ok(Some(Packet::StreamStart(_))) => true,
395            _ => false,
396        });
397        b.clear();
398        b.put(r"</stream:stream>");
399        let r = c.decode(&mut b);
400        assert!(match r {
401            Ok(Some(Packet::StreamEnd)) => true,
402            _ => false,
403        });
404    }
405
406    #[test]
407    fn test_truncated_stanza() {
408        let mut c = XMPPCodec::new();
409        let mut b = BytesMut::with_capacity(1024);
410        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
411        let r = c.decode(&mut b);
412        assert!(match r {
413            Ok(Some(Packet::StreamStart(_))) => true,
414            _ => false,
415        });
416
417        b.clear();
418        b.put(r"<test>ß</test");
419        let r = c.decode(&mut b);
420        assert!(match r {
421            Ok(None) => true,
422            _ => false,
423        });
424
425        b.clear();
426        b.put(r">");
427        let r = c.decode(&mut b);
428        assert!(match r {
429            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
430            _ => false,
431        });
432    }
433
434    #[test]
435    fn test_truncated_utf8() {
436        let mut c = XMPPCodec::new();
437        let mut b = BytesMut::with_capacity(1024);
438        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
439        let r = c.decode(&mut b);
440        assert!(match r {
441            Ok(Some(Packet::StreamStart(_))) => true,
442            _ => false,
443        });
444
445        b.clear();
446        b.put(&b"<test>\xc3"[..]);
447        let r = c.decode(&mut b);
448        assert!(match r {
449            Ok(None) => true,
450            _ => false,
451        });
452
453        b.clear();
454        b.put(&b"\x9f</test>"[..]);
455        let r = c.decode(&mut b);
456        assert!(match r {
457            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
458            _ => false,
459        });
460    }
461
462    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
463    #[test]
464    fn test_atrribute_prefix() {
465        let mut c = XMPPCodec::new();
466        let mut b = BytesMut::with_capacity(1024);
467        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
468        let r = c.decode(&mut b);
469        assert!(match r {
470            Ok(Some(Packet::StreamStart(_))) => true,
471            _ => false,
472        });
473
474        b.clear();
475        b.put(r"<status xml:lang='en'>Test status</status>");
476        let r = c.decode(&mut b);
477        assert!(match r {
478            Ok(Some(Packet::Stanza(ref el)))
479                if el.name() == "status"
480                    && el.text() == "Test status"
481                    && el.attr("xml:lang").map_or(false, |a| a == "en") =>
482                true,
483            _ => false,
484        });
485    }
486
487    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
488    #[test]
489    fn test_large_stanza() {
490        use futures::{Future, Sink};
491        use std::io::Cursor;
492        use tokio_codec::FramedWrite;
493        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
494        let mut text = "".to_owned();
495        for _ in 0..2usize.pow(15) {
496            text = text + "A";
497        }
498        let stanza = Element::builder("message")
499            .append(Element::builder("body").append(text.as_ref()).build())
500            .build();
501        let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
502        assert_eq!(
503            framed.get_ref().get_ref(),
504            &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
505        );
506    }
507
508    #[test]
509    fn test_lone_whitespace() {
510        let mut c = XMPPCodec::new();
511        let mut b = BytesMut::with_capacity(1024);
512        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
513        let r = c.decode(&mut b);
514        assert!(match r {
515            Ok(Some(Packet::StreamStart(_))) => true,
516            _ => false,
517        });
518
519        b.clear();
520        b.put(r" ");
521        let r = c.decode(&mut b);
522        assert!(match r {
523            Ok(None) => true,
524            _ => false,
525        });
526    }
527}