xmpp_codec.rs

//! XML stream codec for XMPP: parses incoming bytes into packets and serializes outgoing packets
  2
  3use crate::{ParseError, ParserError};
  4use bytes::{BufMut, BytesMut};
  5use minidom::Element;
  6use quick_xml::Writer as EventWriter;
  7use std;
  8use std::cell::RefCell;
  9use std::collections::vec_deque::VecDeque;
 10use std::collections::HashMap;
 11use std::default::Default;
 12use std::fmt::Write;
 13use std::io;
 14use std::iter::FromIterator;
 15use std::rc::Rc;
 16use std::str::from_utf8;
 17use tokio_codec::{Decoder, Encoder};
 18use xml5ever::interface::Attribute;
 19use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
 20
 21/// Anything that can be sent or received on an XMPP/XML stream
 22#[derive(Debug, Clone, PartialEq, Eq)]
 23pub enum Packet {
 24    /// `<stream:stream>` start tag
 25    StreamStart(HashMap<String, String>),
 26    /// A complete stanza or nonza
 27    Stanza(Element),
 28    /// Plain text (think whitespace keep-alive)
 29    Text(String),
 30    /// `</stream:stream>` closing tag
 31    StreamEnd,
 32}
 33
 34type QueueItem = Result<Packet, ParserError>;
 35
 36/// Parser state
 37struct ParserSink {
 38    // Ready stanzas, shared with XMPPCodec
 39    queue: Rc<RefCell<VecDeque<QueueItem>>>,
 40    // Parsing stack
 41    stack: Vec<Element>,
 42    ns_stack: Vec<HashMap<Option<String>, String>>,
 43}
 44
 45impl ParserSink {
 46    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
 47        ParserSink {
 48            queue,
 49            stack: vec![],
 50            ns_stack: vec![],
 51        }
 52    }
 53
 54    fn push_queue(&self, pkt: Packet) {
 55        self.queue.borrow_mut().push_back(Ok(pkt));
 56    }
 57
 58    fn push_queue_error(&self, e: ParserError) {
 59        self.queue.borrow_mut().push_back(Err(e));
 60    }
 61
 62    /// Lookup XML namespace declaration for given prefix (or no prefix)
 63    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 64        for nss in self.ns_stack.iter().rev() {
 65            if let Some(ns) = nss.get(prefix) {
 66                return Some(ns);
 67            }
 68        }
 69
 70        None
 71    }
 72
 73    fn handle_start_tag(&mut self, tag: Tag) {
 74        let mut nss = HashMap::new();
 75        let is_prefix_xmlns = |attr: &Attribute| {
 76            attr.name
 77                .prefix
 78                .as_ref()
 79                .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 80                .unwrap_or(false)
 81        };
 82        for attr in &tag.attrs {
 83            match attr.name.local.as_ref() {
 84                "xmlns" => {
 85                    nss.insert(None, attr.value.as_ref().to_owned());
 86                }
 87                prefix if is_prefix_xmlns(attr) => {
 88                    nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 89                }
 90                _ => (),
 91            }
 92        }
 93        self.ns_stack.push(nss);
 94
 95        let el = {
 96            let mut el_builder = Element::builder(tag.name.local.as_ref());
 97            if let Some(el_ns) =
 98                self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
 99            {
100                el_builder = el_builder.ns(el_ns);
101            }
102            for attr in &tag.attrs {
103                match attr.name.local.as_ref() {
104                    "xmlns" => (),
105                    _ if is_prefix_xmlns(attr) => (),
106                    _ => {
107                        if let Some(ref prefix) = attr.name.prefix {
108                            el_builder = el_builder.attr(format!("{}:{}", prefix, attr.name.local), attr.value.as_ref());
109                        } else {
110                            el_builder = el_builder.attr(attr.name.local.as_ref(), attr.value.as_ref());
111                        }
112                    }
113                }
114            }
115            el_builder.build()
116        };
117
118        if self.stack.is_empty() {
119            let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
120                (
121                    attr.name.local.as_ref().to_owned(),
122                    attr.value.as_ref().to_owned(),
123                )
124            }));
125            self.push_queue(Packet::StreamStart(attrs));
126        }
127
128        self.stack.push(el);
129    }
130
131    fn handle_end_tag(&mut self) {
132        let el = self.stack.pop().unwrap();
133        self.ns_stack.pop();
134
135        match self.stack.len() {
136            // </stream:stream>
137            0 => self.push_queue(Packet::StreamEnd),
138            // </stanza>
139            1 => self.push_queue(Packet::Stanza(el)),
140            len => {
141                let parent = &mut self.stack[len - 1];
142                parent.append_child(el);
143            }
144        }
145    }
146}
147
148impl TokenSink for ParserSink {
149    fn process_token(&mut self, token: Token) {
150        match token {
151            Token::TagToken(tag) => match tag.kind {
152                TagKind::StartTag => self.handle_start_tag(tag),
153                TagKind::EndTag => self.handle_end_tag(),
154                TagKind::EmptyTag => {
155                    self.handle_start_tag(tag);
156                    self.handle_end_tag();
157                }
158                TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
159            },
160            Token::CharacterTokens(tendril) => match self.stack.len() {
161                0 | 1 => self.push_queue(Packet::Text(tendril.into())),
162                len => {
163                    let el = &mut self.stack[len - 1];
164                    el.append_text_node(tendril);
165                }
166            },
167            Token::EOFToken => self.push_queue(Packet::StreamEnd),
168            Token::ParseError(s) => {
169                // println!("ParseError: {:?}", s);
170                self.push_queue_error(ParserError::Parse(ParseError(s)));
171            }
172            _ => (),
173        }
174    }
175
176    // fn end(&mut self) {
177    // }
178}
179
180/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
181pub struct XMPPCodec {
182    /// Outgoing
183    ns: Option<String>,
184    /// Incoming
185    parser: XmlTokenizer<ParserSink>,
186    /// For handling incoming truncated utf8
187    // TODO: optimize using  tendrils?
188    buf: Vec<u8>,
189    /// Shared with ParserSink
190    queue: Rc<RefCell<VecDeque<QueueItem>>>,
191}
192
193impl XMPPCodec {
194    /// Constructor
195    pub fn new() -> Self {
196        let queue = Rc::new(RefCell::new(VecDeque::new()));
197        let sink = ParserSink::new(queue.clone());
198        // TODO: configure parser?
199        let parser = XmlTokenizer::new(sink, Default::default());
200        XMPPCodec {
201            ns: None,
202            parser,
203            queue,
204            buf: vec![],
205        }
206    }
207}
208
209impl Default for XMPPCodec {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
215impl Decoder for XMPPCodec {
216    type Item = Packet;
217    type Error = ParserError;
218
219    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
220        let buf1: Box<AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
221            let mut prefix = std::mem::replace(&mut self.buf, vec![]);
222            prefix.extend_from_slice(buf.take().as_ref());
223            Box::new(prefix)
224        } else {
225            Box::new(buf.take())
226        };
227        let buf1 = buf1.as_ref().as_ref();
228        match from_utf8(buf1) {
229            Ok(s) => {
230                if !s.is_empty() {
231                    // println!("<< {}", s);
232                    let tendril = FromIterator::from_iter(s.chars());
233                    self.parser.feed(tendril);
234                }
235            }
236            // Remedies for truncated utf8
237            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
238                // Prepare all the valid data
239                let mut b = BytesMut::with_capacity(e.valid_up_to());
240                b.put(&buf1[0..e.valid_up_to()]);
241
242                // Retry
243                let result = self.decode(&mut b);
244
245                // Keep the tail back in
246                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
247
248                return result;
249            }
250            Err(e) => {
251                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
252                return Err(ParserError::Utf8(e));
253            }
254        }
255
256        match self.queue.borrow_mut().pop_front() {
257            None => Ok(None),
258            Some(result) => result.map(|pkt| Some(pkt)),
259        }
260    }
261
262    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
263        self.decode(buf)
264    }
265}
266
267impl Encoder for XMPPCodec {
268    type Item = Packet;
269    type Error = io::Error;
270
271    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
272        let remaining = dst.capacity() - dst.len();
273        let max_stanza_size: usize = 2usize.pow(16);
274        if remaining < max_stanza_size {
275            dst.reserve(max_stanza_size - remaining);
276        }
277
278        match item {
279            Packet::StreamStart(start_attrs) => {
280                let mut buf = String::new();
281                write!(buf, "<stream:stream").unwrap();
282                for (name, value) in start_attrs {
283                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).unwrap();
284                    if name == "xmlns" {
285                        self.ns = Some(value);
286                    }
287                }
288                write!(buf, ">\n").unwrap();
289
290                print!(">> {}", buf);
291                write!(dst, "{}", buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
292            }
293            Packet::Stanza(stanza) => {
294                stanza
295                    .write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
296                    .and_then(|_| {
297                        // println!(">> {:?}", dst);
298                        Ok(())
299                    })
300                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
301            }
302            Packet::Text(text) => {
303                write_text(&text, dst)
304                    .and_then(|_| {
305                        // println!(">> {:?}", dst);
306                        Ok(())
307                    })
308                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
309            }
310            // TODO: Implement all
311            _ => Ok(()),
312        }
313    }
314}
315
316/// Write XML-escaped text string
317pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
318    write!(writer, "{}", escape(text))
319}
320
/// Returns `input` with `&`, `<`, `>`, `'` and `"` replaced by their XML
/// entity references; all other characters are copied unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut out, ch| {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '\'' => out.push_str("&apos;"),
                '"' => out.push_str("&quot;"),
                other => out.push(other),
            }
            out
        })
}
337
338/// BytesMut impl only std::fmt::Write but not std::io::Write. The
339/// latter trait is required for minidom's
340/// `Element::write_to_inner()`.
341struct WriteBytes<'a> {
342    dst: &'a mut BytesMut,
343}
344
345impl<'a> WriteBytes<'a> {
346    fn new(dst: &'a mut BytesMut) -> Self {
347        WriteBytes { dst }
348    }
349}
350
351impl<'a> std::io::Write for WriteBytes<'a> {
352    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
353        self.dst.put_slice(buf);
354        Ok(buf.len())
355    }
356
357    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
358        Ok(())
359    }
360}
361
362#[cfg(test)]
363mod tests {
364    use super::*;
365    use bytes::BytesMut;
366
367    #[test]
368    fn test_stream_start() {
369        let mut c = XMPPCodec::new();
370        let mut b = BytesMut::with_capacity(1024);
371        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
372        let r = c.decode(&mut b);
373        assert!(match r {
374            Ok(Some(Packet::StreamStart(_))) => true,
375            _ => false,
376        });
377    }
378
379    #[test]
380    fn test_truncated_stanza() {
381        let mut c = XMPPCodec::new();
382        let mut b = BytesMut::with_capacity(1024);
383        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
384        let r = c.decode(&mut b);
385        assert!(match r {
386            Ok(Some(Packet::StreamStart(_))) => true,
387            _ => false,
388        });
389
390        b.clear();
391        b.put(r"<test>ß</test");
392        let r = c.decode(&mut b);
393        assert!(match r {
394            Ok(None) => true,
395            _ => false,
396        });
397
398        b.clear();
399        b.put(r">");
400        let r = c.decode(&mut b);
401        assert!(match r {
402            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
403            _ => false,
404        });
405    }
406
407    #[test]
408    fn test_truncated_utf8() {
409        let mut c = XMPPCodec::new();
410        let mut b = BytesMut::with_capacity(1024);
411        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
412        let r = c.decode(&mut b);
413        assert!(match r {
414            Ok(Some(Packet::StreamStart(_))) => true,
415            _ => false,
416        });
417
418        b.clear();
419        b.put(&b"<test>\xc3"[..]);
420        let r = c.decode(&mut b);
421        assert!(match r {
422            Ok(None) => true,
423            _ => false,
424        });
425
426        b.clear();
427        b.put(&b"\x9f</test>"[..]);
428        let r = c.decode(&mut b);
429        assert!(match r {
430            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
431            _ => false,
432        });
433    }
434
435    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
436    #[test]
437    fn test_atrribute_prefix() {
438        let mut c = XMPPCodec::new();
439        let mut b = BytesMut::with_capacity(1024);
440        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
441        let r = c.decode(&mut b);
442        assert!(match r {
443            Ok(Some(Packet::StreamStart(_))) => true,
444            _ => false,
445        });
446
447        b.clear();
448        b.put(r"<status xml:lang='en'>Test status</status>");
449        let r = c.decode(&mut b);
450        assert!(match r {
451            Ok(Some(Packet::Stanza(ref el))) if el.name() == "status" && el.text() == "Test status" && el.attr("xml:lang").map_or(false, |a| a == "en") => true,
452            _ => false,
453        });
454
455    }
456
457    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
458    #[test]
459    fn test_large_stanza() {
460        use futures::{Future, Sink};
461        use std::io::Cursor;
462        use tokio_codec::FramedWrite;
463        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
464        let mut text = "".to_owned();
465        for _ in 0..2usize.pow(15) {
466            text = text + "A";
467        }
468        let stanza = Element::builder("message")
469            .append(Element::builder("body").append(&text).build())
470            .build();
471        let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
472        assert_eq!(
473            framed.get_ref().get_ref(),
474            &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
475        );
476    }
477}