xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use crate::{ParseError, ParserError};
  4use bytes::{BufMut, BytesMut};
  5use minidom::Element;
  6use quick_xml::Writer as EventWriter;
  7use std;
  8use std::cell::RefCell;
  9use std::collections::vec_deque::VecDeque;
 10use std::collections::HashMap;
 11use std::default::Default;
 12use std::fmt::Write;
 13use std::io;
 14use std::iter::FromIterator;
 15use std::rc::Rc;
 16use std::str::from_utf8;
 17use tokio_codec::{Decoder, Encoder};
 18use xml5ever::interface::Attribute;
 19use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
 20
 21/// Anything that can be sent or received on an XMPP/XML stream
 22#[derive(Debug, Clone, PartialEq, Eq)]
 23pub enum Packet {
 24    /// `<stream:stream>` start tag
 25    StreamStart(HashMap<String, String>),
 26    /// A complete stanza or nonza
 27    Stanza(Element),
 28    /// Plain text (think whitespace keep-alive)
 29    Text(String),
 30    /// `</stream:stream>` closing tag
 31    StreamEnd,
 32}
 33
 34type QueueItem = Result<Packet, ParserError>;
 35
 36/// Parser state
 37struct ParserSink {
 38    // Ready stanzas, shared with XMPPCodec
 39    queue: Rc<RefCell<VecDeque<QueueItem>>>,
 40    // Parsing stack
 41    stack: Vec<Element>,
 42    ns_stack: Vec<HashMap<Option<String>, String>>,
 43}
 44
 45impl ParserSink {
 46    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
 47        ParserSink {
 48            queue,
 49            stack: vec![],
 50            ns_stack: vec![],
 51        }
 52    }
 53
 54    fn push_queue(&self, pkt: Packet) {
 55        self.queue.borrow_mut().push_back(Ok(pkt));
 56    }
 57
 58    fn push_queue_error(&self, e: ParserError) {
 59        self.queue.borrow_mut().push_back(Err(e));
 60    }
 61
 62    /// Lookup XML namespace declaration for given prefix (or no prefix)
 63    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 64        for nss in self.ns_stack.iter().rev() {
 65            if let Some(ns) = nss.get(prefix) {
 66                return Some(ns);
 67            }
 68        }
 69
 70        None
 71    }
 72
 73    fn handle_start_tag(&mut self, tag: Tag) {
 74        let mut nss = HashMap::new();
 75        let is_prefix_xmlns = |attr: &Attribute| {
 76            attr.name
 77                .prefix
 78                .as_ref()
 79                .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 80                .unwrap_or(false)
 81        };
 82        for attr in &tag.attrs {
 83            match attr.name.local.as_ref() {
 84                "xmlns" => {
 85                    nss.insert(None, attr.value.as_ref().to_owned());
 86                }
 87                prefix if is_prefix_xmlns(attr) => {
 88                    nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 89                }
 90                _ => (),
 91            }
 92        }
 93        self.ns_stack.push(nss);
 94
 95        let el = {
 96            let mut el_builder = Element::builder(tag.name.local.as_ref());
 97            if let Some(el_ns) =
 98                self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
 99            {
100                el_builder = el_builder.ns(el_ns);
101            }
102            for attr in &tag.attrs {
103                match attr.name.local.as_ref() {
104                    "xmlns" => (),
105                    _ if is_prefix_xmlns(attr) => (),
106                    _ => {
107                        el_builder = el_builder.attr(attr.name.local.as_ref(), attr.value.as_ref());
108                    }
109                }
110            }
111            el_builder.build()
112        };
113
114        if self.stack.is_empty() {
115            let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
116                (
117                    attr.name.local.as_ref().to_owned(),
118                    attr.value.as_ref().to_owned(),
119                )
120            }));
121            self.push_queue(Packet::StreamStart(attrs));
122        }
123
124        self.stack.push(el);
125    }
126
127    fn handle_end_tag(&mut self) {
128        let el = self.stack.pop().unwrap();
129        self.ns_stack.pop();
130
131        match self.stack.len() {
132            // </stream:stream>
133            0 => self.push_queue(Packet::StreamEnd),
134            // </stanza>
135            1 => self.push_queue(Packet::Stanza(el)),
136            len => {
137                let parent = &mut self.stack[len - 1];
138                parent.append_child(el);
139            }
140        }
141    }
142}
143
144impl TokenSink for ParserSink {
145    fn process_token(&mut self, token: Token) {
146        match token {
147            Token::TagToken(tag) => match tag.kind {
148                TagKind::StartTag => self.handle_start_tag(tag),
149                TagKind::EndTag => self.handle_end_tag(),
150                TagKind::EmptyTag => {
151                    self.handle_start_tag(tag);
152                    self.handle_end_tag();
153                }
154                TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
155            },
156            Token::CharacterTokens(tendril) => match self.stack.len() {
157                0 | 1 => self.push_queue(Packet::Text(tendril.into())),
158                len => {
159                    let el = &mut self.stack[len - 1];
160                    el.append_text_node(tendril);
161                }
162            },
163            Token::EOFToken => self.push_queue(Packet::StreamEnd),
164            Token::ParseError(s) => {
165                // println!("ParseError: {:?}", s);
166                self.push_queue_error(ParserError::Parse(ParseError(s)));
167            }
168            _ => (),
169        }
170    }
171
172    // fn end(&mut self) {
173    // }
174}
175
176/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
177pub struct XMPPCodec {
178    /// Outgoing
179    ns: Option<String>,
180    /// Incoming
181    parser: XmlTokenizer<ParserSink>,
182    /// For handling incoming truncated utf8
183    // TODO: optimize using  tendrils?
184    buf: Vec<u8>,
185    /// Shared with ParserSink
186    queue: Rc<RefCell<VecDeque<QueueItem>>>,
187}
188
189impl XMPPCodec {
190    /// Constructor
191    pub fn new() -> Self {
192        let queue = Rc::new(RefCell::new(VecDeque::new()));
193        let sink = ParserSink::new(queue.clone());
194        // TODO: configure parser?
195        let parser = XmlTokenizer::new(sink, Default::default());
196        XMPPCodec {
197            ns: None,
198            parser,
199            queue,
200            buf: vec![],
201        }
202    }
203}
204
205impl Default for XMPPCodec {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211impl Decoder for XMPPCodec {
212    type Item = Packet;
213    type Error = ParserError;
214
215    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
216        let buf1: Box<AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
217            let mut prefix = std::mem::replace(&mut self.buf, vec![]);
218            prefix.extend_from_slice(buf.take().as_ref());
219            Box::new(prefix)
220        } else {
221            Box::new(buf.take())
222        };
223        let buf1 = buf1.as_ref().as_ref();
224        match from_utf8(buf1) {
225            Ok(s) => {
226                if !s.is_empty() {
227                    // println!("<< {}", s);
228                    let tendril = FromIterator::from_iter(s.chars());
229                    self.parser.feed(tendril);
230                }
231            }
232            // Remedies for truncated utf8
233            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
234                // Prepare all the valid data
235                let mut b = BytesMut::with_capacity(e.valid_up_to());
236                b.put(&buf1[0..e.valid_up_to()]);
237
238                // Retry
239                let result = self.decode(&mut b);
240
241                // Keep the tail back in
242                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
243
244                return result;
245            }
246            Err(e) => {
247                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
248                return Err(ParserError::Utf8(e));
249            }
250        }
251
252        match self.queue.borrow_mut().pop_front() {
253            None => Ok(None),
254            Some(result) => result.map(|pkt| Some(pkt)),
255        }
256    }
257
258    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
259        self.decode(buf)
260    }
261}
262
263impl Encoder for XMPPCodec {
264    type Item = Packet;
265    type Error = io::Error;
266
267    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
268        let remaining = dst.capacity() - dst.len();
269        let max_stanza_size: usize = 2usize.pow(16);
270        if remaining < max_stanza_size {
271            dst.reserve(max_stanza_size - remaining);
272        }
273
274        match item {
275            Packet::StreamStart(start_attrs) => {
276                let mut buf = String::new();
277                write!(buf, "<stream:stream").unwrap();
278                for (name, value) in start_attrs {
279                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).unwrap();
280                    if name == "xmlns" {
281                        self.ns = Some(value);
282                    }
283                }
284                write!(buf, ">\n").unwrap();
285
286                print!(">> {}", buf);
287                write!(dst, "{}", buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
288            }
289            Packet::Stanza(stanza) => {
290                stanza
291                    .write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
292                    .and_then(|_| {
293                        // println!(">> {:?}", dst);
294                        Ok(())
295                    })
296                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
297            }
298            Packet::Text(text) => {
299                write_text(&text, dst)
300                    .and_then(|_| {
301                        // println!(">> {:?}", dst);
302                        Ok(())
303                    })
304                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
305            }
306            // TODO: Implement all
307            _ => Ok(()),
308        }
309    }
310}
311
312/// Write XML-escaped text string
313pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
314    write!(writer, "{}", escape(text))
315}
316
/// Returns `input` with `&`, `<`, `>`, `'` and `"` replaced by their XML
/// entity references; every other character is copied unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            _ => {
                escaped.push(ch);
                continue;
            }
        };
        escaped.push_str(entity);
    }
    escaped
}
333
334/// BytesMut impl only std::fmt::Write but not std::io::Write. The
335/// latter trait is required for minidom's
336/// `Element::write_to_inner()`.
337struct WriteBytes<'a> {
338    dst: &'a mut BytesMut,
339}
340
341impl<'a> WriteBytes<'a> {
342    fn new(dst: &'a mut BytesMut) -> Self {
343        WriteBytes { dst }
344    }
345}
346
347impl<'a> std::io::Write for WriteBytes<'a> {
348    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
349        self.dst.put_slice(buf);
350        Ok(buf.len())
351    }
352
353    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
354        Ok(())
355    }
356}
357
358#[cfg(test)]
359mod tests {
360    use super::*;
361    use bytes::BytesMut;
362
363    #[test]
364    fn test_stream_start() {
365        let mut c = XMPPCodec::new();
366        let mut b = BytesMut::with_capacity(1024);
367        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
368        let r = c.decode(&mut b);
369        assert!(match r {
370            Ok(Some(Packet::StreamStart(_))) => true,
371            _ => false,
372        });
373    }
374
375    #[test]
376    fn test_truncated_stanza() {
377        let mut c = XMPPCodec::new();
378        let mut b = BytesMut::with_capacity(1024);
379        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
380        let r = c.decode(&mut b);
381        assert!(match r {
382            Ok(Some(Packet::StreamStart(_))) => true,
383            _ => false,
384        });
385
386        b.clear();
387        b.put(r"<test>ß</test");
388        let r = c.decode(&mut b);
389        assert!(match r {
390            Ok(None) => true,
391            _ => false,
392        });
393
394        b.clear();
395        b.put(r">");
396        let r = c.decode(&mut b);
397        assert!(match r {
398            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
399            _ => false,
400        });
401    }
402
403    #[test]
404    fn test_truncated_utf8() {
405        let mut c = XMPPCodec::new();
406        let mut b = BytesMut::with_capacity(1024);
407        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
408        let r = c.decode(&mut b);
409        assert!(match r {
410            Ok(Some(Packet::StreamStart(_))) => true,
411            _ => false,
412        });
413
414        b.clear();
415        b.put(&b"<test>\xc3"[..]);
416        let r = c.decode(&mut b);
417        assert!(match r {
418            Ok(None) => true,
419            _ => false,
420        });
421
422        b.clear();
423        b.put(&b"\x9f</test>"[..]);
424        let r = c.decode(&mut b);
425        assert!(match r {
426            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
427            _ => false,
428        });
429    }
430
431    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
432    #[test]
433    fn test_large_stanza() {
434        use futures::{Future, Sink};
435        use std::io::Cursor;
436        use tokio_codec::FramedWrite;
437        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
438        let mut text = "".to_owned();
439        for _ in 0..2usize.pow(15) {
440            text = text + "A";
441        }
442        let stanza = Element::builder("message")
443            .append(Element::builder("body").append(&text).build())
444            .build();
445        let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
446        assert_eq!(
447            framed.get_ref().get_ref(),
448            &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
449        );
450    }
451}