xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use std;
  4use std::default::Default;
  5use std::iter::FromIterator;
  6use std::cell::RefCell;
  7use std::rc::Rc;
  8use std::fmt::Write;
  9use std::str::{from_utf8, Utf8Error};
 10use std::io;
 11use std::collections::HashMap;
 12use std::collections::vec_deque::VecDeque;
 13use std::error::Error as StdError;
 14use std::fmt;
 15use std::borrow::Cow;
 16use tokio_codec::{Encoder, Decoder};
 17use minidom::Element;
 18use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 19use xml5ever::interface::Attribute;
 20use bytes::{BytesMut, BufMut};
 21use quick_xml::Writer as EventWriter;
 22
 23/// Anything that can be sent or received on an XMPP/XML stream
 24#[derive(Debug)]
 25pub enum Packet {
 26    /// `<stream:stream>` start tag
 27    StreamStart(HashMap<String, String>),
 28    /// A complete stanza or nonza
 29    Stanza(Element),
 30    /// Plain text (think whitespace keep-alive)
 31    Text(String),
 32    /// `</stream:stream>` closing tag
 33    StreamEnd,
 34}
 35
 36/// Causes for stream parsing errors
 37#[derive(Debug)]
 38pub enum ParserError {
 39    /// Encoding error
 40    Utf8(Utf8Error),
 41    /// XML parse error
 42    Parse(Cow<'static, str>),
 43    /// Illegal `</>`
 44    ShortTag,
 45    /// Required by `impl Decoder`
 46    IO(io::Error),
 47}
 48
 49impl From<io::Error> for ParserError {
 50    fn from(e: io::Error) -> Self {
 51        ParserError::IO(e)
 52    }
 53}
 54
 55impl StdError for ParserError {
 56    fn description(&self) -> &str {
 57        match *self {
 58            ParserError::Utf8(ref ue) => ue.description(),
 59            ParserError::Parse(ref pe) => pe,
 60            ParserError::ShortTag => "short tag",
 61            ParserError::IO(ref ie) => ie.description(),
 62        }
 63    }
 64
 65    fn cause(&self) -> Option<&StdError> {
 66        match *self {
 67            ParserError::Utf8(ref ue) => ue.cause(),
 68            ParserError::IO(ref ie) => ie.cause(),
 69            _ => None,
 70        }
 71    }
 72}
 73
 74impl fmt::Display for ParserError {
 75    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
 76        match *self {
 77            ParserError::Utf8(ref ue) => write!(f, "{}", ue),
 78            ParserError::Parse(ref pe) => write!(f, "{}", pe),
 79            ParserError::ShortTag => write!(f, "Short tag"),
 80            ParserError::IO(ref ie) => write!(f, "{}", ie),
 81        }
 82    }
 83}
 84
 85type QueueItem = Result<Packet, ParserError>;
 86
 87/// Parser state
 88struct ParserSink {
 89    // Ready stanzas, shared with XMPPCodec
 90    queue: Rc<RefCell<VecDeque<QueueItem>>>,
 91    // Parsing stack
 92    stack: Vec<Element>,
 93    ns_stack: Vec<HashMap<Option<String>, String>>,
 94}
 95
 96impl ParserSink {
 97    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
 98        ParserSink {
 99            queue,
100            stack: vec![],
101            ns_stack: vec![],
102        }
103    }
104
105    fn push_queue(&self, pkt: Packet) {
106        self.queue.borrow_mut().push_back(Ok(pkt));
107    }
108
109    fn push_queue_error(&self, e: ParserError) {
110        self.queue.borrow_mut().push_back(Err(e));
111    }
112
113    /// Lookup XML namespace declaration for given prefix (or no prefix)
114    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
115        for nss in self.ns_stack.iter().rev() {
116            if let Some(ns) = nss.get(prefix) {
117                return Some(ns);
118            }
119        }
120
121        None
122    }
123
124    fn handle_start_tag(&mut self, tag: Tag) {
125        let mut nss = HashMap::new();
126        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
127            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
128            .unwrap_or(false);
129        for attr in &tag.attrs {
130            match attr.name.local.as_ref() {
131                "xmlns" => {
132                    nss.insert(None, attr.value.as_ref().to_owned());
133                },
134                prefix if is_prefix_xmlns(attr) => {
135                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
136                    },
137                _ => (),
138            }
139        }
140        self.ns_stack.push(nss);
141
142        let el = {
143            let mut el_builder = Element::builder(tag.name.local.as_ref());
144            if let Some(el_ns) = self.lookup_ns(
145                &tag.name.prefix.map(|prefix|
146                                     prefix.as_ref().to_owned())
147            ) {
148                el_builder = el_builder.ns(el_ns);
149            }
150            for attr in &tag.attrs {
151                match attr.name.local.as_ref() {
152                    "xmlns" => (),
153                    _ if is_prefix_xmlns(attr) => (),
154                    _ => {
155                        el_builder = el_builder.attr(
156                            attr.name.local.as_ref(),
157                            attr.value.as_ref()
158                        );
159                    },
160                }
161            }
162            el_builder.build()
163        };
164
165        if self.stack.is_empty() {
166            let attrs = HashMap::from_iter(
167                tag.attrs.iter()
168                    .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
169            );
170            self.push_queue(Packet::StreamStart(attrs));
171        }
172
173        self.stack.push(el);
174    }
175
176    fn handle_end_tag(&mut self) {
177        let el = self.stack.pop().unwrap();
178        self.ns_stack.pop();
179
180        match self.stack.len() {
181            // </stream:stream>
182            0 =>
183                self.push_queue(Packet::StreamEnd),
184            // </stanza>
185            1 =>
186                self.push_queue(Packet::Stanza(el)),
187            len => {
188                let parent = &mut self.stack[len - 1];
189                parent.append_child(el);
190            },
191        }
192    }
193}
194
195impl TokenSink for ParserSink {
196    fn process_token(&mut self, token: Token) {
197        match token {
198            Token::TagToken(tag) => match tag.kind {
199                TagKind::StartTag =>
200                    self.handle_start_tag(tag),
201                TagKind::EndTag =>
202                    self.handle_end_tag(),
203                TagKind::EmptyTag => {
204                    self.handle_start_tag(tag);
205                    self.handle_end_tag();
206                },
207                TagKind::ShortTag =>
208                    self.push_queue_error(ParserError::ShortTag),
209            },
210            Token::CharacterTokens(tendril) =>
211                match self.stack.len() {
212                    0 | 1 =>
213                        self.push_queue(Packet::Text(tendril.into())),
214                    len => {
215                        let el = &mut self.stack[len - 1];
216                        el.append_text_node(tendril);
217                    },
218                },
219            Token::EOFToken =>
220                self.push_queue(Packet::StreamEnd),
221            Token::ParseError(s) => {
222                // println!("ParseError: {:?}", s);
223                self.push_queue_error(ParserError::Parse(s));
224            },
225            _ => (),
226        }
227    }
228
229    // fn end(&mut self) {
230    // }
231}
232
233/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
234pub struct XMPPCodec {
235    /// Outgoing
236    ns: Option<String>,
237    /// Incoming
238    parser: XmlTokenizer<ParserSink>,
239    /// For handling incoming truncated utf8
240    // TODO: optimize using  tendrils?
241    buf: Vec<u8>,
242    /// Shared with ParserSink
243    queue: Rc<RefCell<VecDeque<QueueItem>>>,
244}
245
246impl XMPPCodec {
247    /// Constructor
248    pub fn new() -> Self {
249        let queue = Rc::new(RefCell::new(VecDeque::new()));
250        let sink = ParserSink::new(queue.clone());
251        // TODO: configure parser?
252        let parser = XmlTokenizer::new(sink, Default::default());
253        XMPPCodec {
254            ns: None,
255            parser,
256            queue,
257            buf: vec![],
258        }
259    }
260}
261
262impl Default for XMPPCodec {
263    fn default() -> Self {
264        Self::new()
265    }
266}
267
268impl Decoder for XMPPCodec {
269    type Item = Packet;
270    type Error = ParserError;
271
272    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
273        let buf1: Box<AsRef<[u8]>> =
274            if ! self.buf.is_empty() && ! buf.is_empty() {
275                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
276                prefix.extend_from_slice(buf.take().as_ref());
277                Box::new(prefix)
278            } else {
279                Box::new(buf.take())
280            };
281        let buf1 = buf1.as_ref().as_ref();
282        match from_utf8(buf1) {
283            Ok(s) => {
284                if ! s.is_empty() {
285                    // println!("<< {}", s);
286                    let tendril = FromIterator::from_iter(s.chars());
287                    self.parser.feed(tendril);
288                }
289            },
290            // Remedies for truncated utf8
291            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
292                // Prepare all the valid data
293                let mut b = BytesMut::with_capacity(e.valid_up_to());
294                b.put(&buf1[0..e.valid_up_to()]);
295
296                // Retry
297                let result = self.decode(&mut b);
298
299                // Keep the tail back in
300                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
301
302                return result;
303            },
304            Err(e) => {
305                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
306                return Err(ParserError::Utf8(e));
307            },
308        }
309
310        match self.queue.borrow_mut().pop_front() {
311            None => Ok(None),
312            Some(result) =>
313                result.map(|pkt| Some(pkt)),
314        }
315    }
316
317    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
318        self.decode(buf)
319    }
320}
321
322impl Encoder for XMPPCodec {
323    type Item = Packet;
324    type Error = io::Error;
325
326    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
327        let remaining = dst.capacity() - dst.len();
328        let max_stanza_size: usize = 2usize.pow(16);
329        if remaining < max_stanza_size {
330            dst.reserve(max_stanza_size - remaining);
331        }
332
333        match item {
334            Packet::StreamStart(start_attrs) => {
335                let mut buf = String::new();
336                write!(buf, "<stream:stream").unwrap();
337                for (name, value) in start_attrs {
338                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
339                        .unwrap();
340                    if name == "xmlns" {
341                        self.ns = Some(value);
342                    }
343                }
344                write!(buf, ">\n").unwrap();
345
346                print!(">> {}", buf);
347                write!(dst, "{}", buf)
348                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
349            },
350            Packet::Stanza(stanza) => {
351                stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
352                    .and_then(|_| {
353                        // println!(">> {:?}", dst);
354                        Ok(())
355                    })
356                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
357            },
358            Packet::Text(text) => {
359                write_text(&text, dst)
360                    .and_then(|_| {
361                        // println!(">> {:?}", dst);
362                        Ok(())
363                    })
364                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
365            },
366            // TODO: Implement all
367            _ => Ok(())
368        }
369    }
370}
371
372/// Write XML-escaped text string
373pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
374    write!(writer, "{}", escape(text))
375}
376
/// Escape the five XML special characters (`&`, `<`, `>`, `'`, `"`) as
/// predefined entities; every other character is copied unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            other => {
                escaped.push(other);
                continue;
            }
        };
        escaped.push_str(entity);
    }
    escaped
}
393
394/// BytesMut impl only std::fmt::Write but not std::io::Write. The
395/// latter trait is required for minidom's
396/// `Element::write_to_inner()`.
397struct WriteBytes<'a> {
398    dst: &'a mut BytesMut,
399}
400
401impl<'a> WriteBytes<'a> {
402    fn new(dst: &'a mut BytesMut) -> Self {
403        WriteBytes { dst }
404    }
405}
406
407impl<'a> std::io::Write for WriteBytes<'a> {
408    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
409        self.dst.put_slice(buf);
410        Ok(buf.len())
411    }
412
413    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
414        Ok(())
415    }
416}
417
418
419#[cfg(test)]
420mod tests {
421    use super::*;
422    use bytes::BytesMut;
423
424    #[test]
425    fn test_stream_start() {
426        let mut c = XMPPCodec::new();
427        let mut b = BytesMut::with_capacity(1024);
428        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
429        let r = c.decode(&mut b);
430        assert!(match r {
431            Ok(Some(Packet::StreamStart(_))) => true,
432            _ => false,
433        });
434    }
435
436    #[test]
437    fn test_truncated_stanza() {
438        let mut c = XMPPCodec::new();
439        let mut b = BytesMut::with_capacity(1024);
440        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
441        let r = c.decode(&mut b);
442        assert!(match r {
443            Ok(Some(Packet::StreamStart(_))) => true,
444            _ => false,
445        });
446
447        b.clear();
448        b.put(r"<test>ß</test");
449        let r = c.decode(&mut b);
450        assert!(match r {
451            Ok(None) => true,
452            _ => false,
453        });
454
455        b.clear();
456        b.put(r">");
457        let r = c.decode(&mut b);
458        assert!(match r {
459            Ok(Some(Packet::Stanza(ref el)))
460                if el.name() == "test"
461                && el.text() == "ß"
462                => true,
463            _ => false,
464        });
465    }
466
467    #[test]
468    fn test_truncated_utf8() {
469        let mut c = XMPPCodec::new();
470        let mut b = BytesMut::with_capacity(1024);
471        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
472        let r = c.decode(&mut b);
473        assert!(match r {
474            Ok(Some(Packet::StreamStart(_))) => true,
475            _ => false,
476        });
477
478        b.clear();
479        b.put(&b"<test>\xc3"[..]);
480        let r = c.decode(&mut b);
481        assert!(match r {
482            Ok(None) => true,
483            _ => false,
484        });
485
486        b.clear();
487        b.put(&b"\x9f</test>"[..]);
488        let r = c.decode(&mut b);
489        assert!(match r {
490            Ok(Some(Packet::Stanza(ref el)))
491                if el.name() == "test"
492                && el.text() == "ß"
493                => true,
494            _ => false,
495        });
496    }
497
498    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
499    #[test]
500    fn test_large_stanza() {
501        use std::io::Cursor;
502        use futures::{Future, Sink};
503        use tokio_codec::FramedWrite;
504        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
505        let mut text = "".to_owned();
506        for _ in 0..2usize.pow(15) {
507            text = text + "A";
508        }
509        let stanza = Element::builder("message")
510            .append(
511                Element::builder("body")
512                    .append(&text)
513                    .build()
514            )
515            .build();
516        let framed = framed.send(Packet::Stanza(stanza))
517            .wait()
518            .expect("send");
519        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
520    }
521}