xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use crate::{ParseError, ParserError};
  4use bytes::{BufMut, BytesMut};
  5use xmpp_parsers::Element;
  6use quick_xml::Writer as EventWriter;
  7use std;
  8use std::cell::RefCell;
  9use std::collections::vec_deque::VecDeque;
 10use std::collections::HashMap;
 11use std::default::Default;
 12use std::fmt::Write;
 13use std::io;
 14use std::iter::FromIterator;
 15use std::rc::Rc;
 16use std::str::from_utf8;
 17use std::borrow::Cow;
 18use tokio_codec::{Decoder, Encoder};
 19use xml5ever::interface::Attribute;
 20use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
 21use xml5ever::buffer_queue::BufferQueue;
 22
 23/// Anything that can be sent or received on an XMPP/XML stream
 24#[derive(Debug, Clone, PartialEq, Eq)]
 25pub enum Packet {
 26    /// `<stream:stream>` start tag
 27    StreamStart(HashMap<String, String>),
 28    /// A complete stanza or nonza
 29    Stanza(Element),
 30    /// Plain text (think whitespace keep-alive)
 31    Text(String),
 32    /// `</stream:stream>` closing tag
 33    StreamEnd,
 34}
 35
 36type QueueItem = Result<Packet, ParserError>;
 37
 38/// Parser state
 39struct ParserSink {
 40    // Ready stanzas, shared with XMPPCodec
 41    queue: Rc<RefCell<VecDeque<QueueItem>>>,
 42    // Parsing stack
 43    stack: Vec<Element>,
 44    ns_stack: Vec<HashMap<Option<String>, String>>,
 45}
 46
 47impl ParserSink {
 48    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
 49        ParserSink {
 50            queue,
 51            stack: vec![],
 52            ns_stack: vec![],
 53        }
 54    }
 55
 56    fn push_queue(&self, pkt: Packet) {
 57        self.queue.borrow_mut().push_back(Ok(pkt));
 58    }
 59
 60    fn push_queue_error(&self, e: ParserError) {
 61        self.queue.borrow_mut().push_back(Err(e));
 62    }
 63
 64    /// Lookup XML namespace declaration for given prefix (or no prefix)
 65    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 66        for nss in self.ns_stack.iter().rev() {
 67            if let Some(ns) = nss.get(prefix) {
 68                return Some(ns);
 69            }
 70        }
 71
 72        None
 73    }
 74
 75    fn handle_start_tag(&mut self, tag: Tag) {
 76        let mut nss = HashMap::new();
 77        let is_prefix_xmlns = |attr: &Attribute| {
 78            attr.name
 79                .prefix
 80                .as_ref()
 81                .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 82                .unwrap_or(false)
 83        };
 84        for attr in &tag.attrs {
 85            match attr.name.local.as_ref() {
 86                "xmlns" => {
 87                    nss.insert(None, attr.value.as_ref().to_owned());
 88                }
 89                prefix if is_prefix_xmlns(attr) => {
 90                    nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 91                }
 92                _ => (),
 93            }
 94        }
 95        self.ns_stack.push(nss);
 96
 97        let el = {
 98            let mut el_builder = Element::builder(tag.name.local.as_ref());
 99            if let Some(el_ns) =
100                self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
101            {
102                el_builder = el_builder.ns(el_ns);
103            }
104            for attr in &tag.attrs {
105                match attr.name.local.as_ref() {
106                    "xmlns" => (),
107                    _ if is_prefix_xmlns(attr) => (),
108                    _ => {
109                        let attr_name = if let Some(ref prefix) = attr.name.prefix {
110                            Cow::Owned(format!("{}:{}", prefix, attr.name.local))
111                        } else {
112                            Cow::Borrowed(attr.name.local.as_ref())
113                        };
114                        el_builder = el_builder.attr(attr_name, attr.value.as_ref());
115                    }
116                }
117            }
118            el_builder.build()
119        };
120
121        if self.stack.is_empty() {
122            let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
123                (
124                    attr.name.local.as_ref().to_owned(),
125                    attr.value.as_ref().to_owned(),
126                )
127            }));
128            self.push_queue(Packet::StreamStart(attrs));
129        }
130
131        self.stack.push(el);
132    }
133
134    fn handle_end_tag(&mut self) {
135        let el = self.stack.pop().unwrap();
136        self.ns_stack.pop();
137
138        match self.stack.len() {
139            // </stream:stream>
140            0 => self.push_queue(Packet::StreamEnd),
141            // </stanza>
142            1 => self.push_queue(Packet::Stanza(el)),
143            len => {
144                let parent = &mut self.stack[len - 1];
145                parent.append_child(el);
146            }
147        }
148    }
149}
150
151impl TokenSink for ParserSink {
152    fn process_token(&mut self, token: Token) {
153        match token {
154            Token::TagToken(tag) => match tag.kind {
155                TagKind::StartTag => self.handle_start_tag(tag),
156                TagKind::EndTag => self.handle_end_tag(),
157                TagKind::EmptyTag => {
158                    self.handle_start_tag(tag);
159                    self.handle_end_tag();
160                }
161                TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
162            },
163            Token::CharacterTokens(tendril) => match self.stack.len() {
164                0 | 1 => self.push_queue(Packet::Text(tendril.into())),
165                len => {
166                    let el = &mut self.stack[len - 1];
167                    el.append_text_node(tendril);
168                }
169            },
170            Token::EOFToken => self.push_queue(Packet::StreamEnd),
171            Token::ParseError(s) => {
172                // println!("ParseError: {:?}", s);
173                self.push_queue_error(ParserError::Parse(ParseError(s)));
174            }
175            _ => (),
176        }
177    }
178
179    // fn end(&mut self) {
180    // }
181}
182
183/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
184pub struct XMPPCodec {
185    /// Outgoing
186    ns: Option<String>,
187    /// Incoming
188    parser: XmlTokenizer<ParserSink>,
189    /// For handling incoming truncated utf8
190    // TODO: optimize using  tendrils?
191    buf: Vec<u8>,
192    /// Shared with ParserSink
193    queue: Rc<RefCell<VecDeque<QueueItem>>>,
194}
195
196impl XMPPCodec {
197    /// Constructor
198    pub fn new() -> Self {
199        let queue = Rc::new(RefCell::new(VecDeque::new()));
200        let sink = ParserSink::new(queue.clone());
201        // TODO: configure parser?
202        let parser = XmlTokenizer::new(sink, Default::default());
203        XMPPCodec {
204            ns: None,
205            parser,
206            queue,
207            buf: vec![],
208        }
209    }
210}
211
212impl Default for XMPPCodec {
213    fn default() -> Self {
214        Self::new()
215    }
216}
217
218impl Decoder for XMPPCodec {
219    type Item = Packet;
220    type Error = ParserError;
221
222    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
223        let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
224            let mut prefix = std::mem::replace(&mut self.buf, vec![]);
225            prefix.extend_from_slice(buf.take().as_ref());
226            Box::new(prefix)
227        } else {
228            Box::new(buf.take())
229        };
230        let buf1 = buf1.as_ref().as_ref();
231        match from_utf8(buf1) {
232            Ok(mut s) => {
233                s = s.trim();
234                if !s.is_empty() {
235                    // println!("<< {}", s);
236                    let mut buffer_queue = BufferQueue::new();
237                    let tendril = FromIterator::from_iter(s.chars());
238                    buffer_queue.push_back(tendril);
239                    self.parser.feed(&mut buffer_queue);
240                }
241            }
242            // Remedies for truncated utf8
243            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
244                // Prepare all the valid data
245                let mut b = BytesMut::with_capacity(e.valid_up_to());
246                b.put(&buf1[0..e.valid_up_to()]);
247
248                // Retry
249                let result = self.decode(&mut b);
250
251                // Keep the tail back in
252                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
253
254                return result;
255            }
256            Err(e) => {
257                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
258                return Err(ParserError::Utf8(e));
259            }
260        }
261
262        match self.queue.borrow_mut().pop_front() {
263            None => Ok(None),
264            Some(result) => result.map(|pkt| Some(pkt)),
265        }
266    }
267
268    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
269        self.decode(buf)
270    }
271}
272
273impl Encoder for XMPPCodec {
274    type Item = Packet;
275    type Error = io::Error;
276
277    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
278        let remaining = dst.capacity() - dst.len();
279        let max_stanza_size: usize = 2usize.pow(16);
280        if remaining < max_stanza_size {
281            dst.reserve(max_stanza_size - remaining);
282        }
283
284        fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
285            io::Error::new(io::ErrorKind::InvalidInput, e)
286        }
287
288        match item {
289            Packet::StreamStart(start_attrs) => {
290                let mut buf = String::new();
291                write!(buf, "<stream:stream")
292                    .map_err(to_io_err)?;
293                for (name, value) in start_attrs {
294                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
295                        .map_err(to_io_err)?;
296                    if name == "xmlns" {
297                        self.ns = Some(value);
298                    }
299                }
300                write!(buf, ">\n")
301                    .map_err(to_io_err)?;
302
303                // print!(">> {}", buf);
304                write!(dst, "{}", buf)
305                    .map_err(to_io_err)
306            }
307            Packet::Stanza(stanza) => {
308                stanza
309                    .write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
310                    .and_then(|_| {
311                        // println!(">> {:?}", dst);
312                        Ok(())
313                    })
314                    .map_err(|e| to_io_err(format!("{}", e)))
315            }
316            Packet::Text(text) => {
317                write_text(&text, dst)
318                    .and_then(|_| {
319                        // println!(">> {:?}", dst);
320                        Ok(())
321                    })
322                    .map_err(to_io_err)
323            }
324            Packet::StreamEnd => {
325                write!(dst, "</stream:stream>\n")
326                    .map_err(to_io_err)
327            }
328        }
329    }
330}
331
332/// Write XML-escaped text string
333pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
334    write!(writer, "{}", escape(text))
335}
336
/// Replace the five XML special characters (`&`, `<`, `>`, `'`, `"`) by
/// their predefined entities; every other character is copied unchanged.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());

    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            other => {
                escaped.push(other);
                continue;
            }
        };
        escaped.push_str(entity);
    }
    escaped
}
353
354/// BytesMut impl only std::fmt::Write but not std::io::Write. The
355/// latter trait is required for minidom's
356/// `Element::write_to_inner()`.
357struct WriteBytes<'a> {
358    dst: &'a mut BytesMut,
359}
360
361impl<'a> WriteBytes<'a> {
362    fn new(dst: &'a mut BytesMut) -> Self {
363        WriteBytes { dst }
364    }
365}
366
367impl<'a> std::io::Write for WriteBytes<'a> {
368    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
369        self.dst.put_slice(buf);
370        Ok(buf.len())
371    }
372
373    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
374        Ok(())
375    }
376}
377
378#[cfg(test)]
379mod tests {
380    use super::*;
381    use bytes::BytesMut;
382
383    #[test]
384    fn test_stream_start() {
385        let mut c = XMPPCodec::new();
386        let mut b = BytesMut::with_capacity(1024);
387        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
388        let r = c.decode(&mut b);
389        assert!(match r {
390            Ok(Some(Packet::StreamStart(_))) => true,
391            _ => false,
392        });
393    }
394
395    #[test]
396    fn test_stream_end() {
397        let mut c = XMPPCodec::new();
398        let mut b = BytesMut::with_capacity(1024);
399        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
400        let r = c.decode(&mut b);
401        assert!(match r {
402            Ok(Some(Packet::StreamStart(_))) => true,
403            _ => false,
404        });
405        b.clear();
406        b.put(r"</stream:stream>");
407        let r = c.decode(&mut b);
408        assert!(match r {
409            Ok(Some(Packet::StreamEnd)) => true,
410            _ => false,
411        });
412    }
413
414    #[test]
415    fn test_truncated_stanza() {
416        let mut c = XMPPCodec::new();
417        let mut b = BytesMut::with_capacity(1024);
418        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
419        let r = c.decode(&mut b);
420        assert!(match r {
421            Ok(Some(Packet::StreamStart(_))) => true,
422            _ => false,
423        });
424
425        b.clear();
426        b.put(r"<test>ß</test");
427        let r = c.decode(&mut b);
428        assert!(match r {
429            Ok(None) => true,
430            _ => false,
431        });
432
433        b.clear();
434        b.put(r">");
435        let r = c.decode(&mut b);
436        assert!(match r {
437            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
438            _ => false,
439        });
440    }
441
442    #[test]
443    fn test_truncated_utf8() {
444        let mut c = XMPPCodec::new();
445        let mut b = BytesMut::with_capacity(1024);
446        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
447        let r = c.decode(&mut b);
448        assert!(match r {
449            Ok(Some(Packet::StreamStart(_))) => true,
450            _ => false,
451        });
452
453        b.clear();
454        b.put(&b"<test>\xc3"[..]);
455        let r = c.decode(&mut b);
456        assert!(match r {
457            Ok(None) => true,
458            _ => false,
459        });
460
461        b.clear();
462        b.put(&b"\x9f</test>"[..]);
463        let r = c.decode(&mut b);
464        assert!(match r {
465            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
466            _ => false,
467        });
468    }
469
470    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
471    #[test]
472    fn test_atrribute_prefix() {
473        let mut c = XMPPCodec::new();
474        let mut b = BytesMut::with_capacity(1024);
475        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
476        let r = c.decode(&mut b);
477        assert!(match r {
478            Ok(Some(Packet::StreamStart(_))) => true,
479            _ => false,
480        });
481
482        b.clear();
483        b.put(r"<status xml:lang='en'>Test status</status>");
484        let r = c.decode(&mut b);
485        assert!(match r {
486            Ok(Some(Packet::Stanza(ref el))) if el.name() == "status" && el.text() == "Test status" && el.attr("xml:lang").map_or(false, |a| a == "en") => true,
487            _ => false,
488        });
489
490    }
491
492    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
493    #[test]
494    fn test_large_stanza() {
495        use futures::{Future, Sink};
496        use std::io::Cursor;
497        use tokio_codec::FramedWrite;
498        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
499        let mut text = "".to_owned();
500        for _ in 0..2usize.pow(15) {
501            text = text + "A";
502        }
503        let stanza = Element::builder("message")
504            .append(Element::builder("body").append(text.as_ref()).build())
505            .build();
506        let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
507        assert_eq!(
508            framed.get_ref().get_ref(),
509            &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
510        );
511    }
512
513    #[test]
514    fn test_lone_whitespace() {
515        let mut c = XMPPCodec::new();
516        let mut b = BytesMut::with_capacity(1024);
517        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
518        let r = c.decode(&mut b);
519        assert!(match r {
520            Ok(Some(Packet::StreamStart(_))) => true,
521            _ => false,
522        });
523
524        b.clear();
525        b.put(r" ");
526        let r = c.decode(&mut b);
527        assert!(match r {
528            Ok(None) => true,
529            _ => false,
530        });
531    }
532}