xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use std;
  4use std::default::Default;
  5use std::iter::FromIterator;
  6use std::cell::RefCell;
  7use std::rc::Rc;
  8use std::fmt::Write;
  9use std::str::from_utf8;
 10use std::io::{Error, ErrorKind};
 11use std::collections::HashMap;
 12use std::collections::vec_deque::VecDeque;
 13use tokio_codec::{Encoder, Decoder};
 14use minidom::Element;
 15use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 16use xml5ever::interface::Attribute;
 17use bytes::{BytesMut, BufMut};
 18use quick_xml::Writer as EventWriter;
 19
 20/// Anything that can be sent or received on an XMPP/XML stream
 21#[derive(Debug)]
 22pub enum Packet {
 23    /// General error (`InvalidInput`)
 24    Error(Box<std::error::Error>),
 25    /// `<stream:stream>` start tag
 26    StreamStart(HashMap<String, String>),
 27    /// A complete stanza or nonza
 28    Stanza(Element),
 29    /// Plain text (think whitespace keep-alive)
 30    Text(String),
 31    /// `</stream:stream>` closing tag
 32    StreamEnd,
 33}
 34
 35/// Parser state
 36struct ParserSink {
 37    // Ready stanzas, shared with XMPPCodec
 38    queue: Rc<RefCell<VecDeque<Packet>>>,
 39    // Parsing stack
 40    stack: Vec<Element>,
 41    ns_stack: Vec<HashMap<Option<String>, String>>,
 42}
 43
 44impl ParserSink {
 45    pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
 46        ParserSink {
 47            queue,
 48            stack: vec![],
 49            ns_stack: vec![],
 50        }
 51    }
 52
 53    fn push_queue(&self, pkt: Packet) {
 54        self.queue.borrow_mut().push_back(pkt);
 55    }
 56
 57    /// Lookup XML namespace declaration for given prefix (or no prefix)
 58    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 59        for nss in self.ns_stack.iter().rev() {
 60            if let Some(ns) = nss.get(prefix) {
 61                return Some(ns);
 62            }
 63        }
 64
 65        None
 66    }
 67
 68    fn handle_start_tag(&mut self, tag: Tag) {
 69        let mut nss = HashMap::new();
 70        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
 71            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 72            .unwrap_or(false);
 73        for attr in &tag.attrs {
 74            match attr.name.local.as_ref() {
 75                "xmlns" => {
 76                    nss.insert(None, attr.value.as_ref().to_owned());
 77                },
 78                prefix if is_prefix_xmlns(attr) => {
 79                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 80                    },
 81                _ => (),
 82            }
 83        }
 84        self.ns_stack.push(nss);
 85
 86        let el = {
 87            let mut el_builder = Element::builder(tag.name.local.as_ref());
 88            if let Some(el_ns) = self.lookup_ns(
 89                &tag.name.prefix.map(|prefix|
 90                                     prefix.as_ref().to_owned())
 91            ) {
 92                el_builder = el_builder.ns(el_ns);
 93            }
 94            for attr in &tag.attrs {
 95                match attr.name.local.as_ref() {
 96                    "xmlns" => (),
 97                    _ if is_prefix_xmlns(attr) => (),
 98                    _ => {
 99                        el_builder = el_builder.attr(
100                            attr.name.local.as_ref(),
101                            attr.value.as_ref()
102                        );
103                    },
104                }
105            }
106            el_builder.build()
107        };
108
109        if self.stack.is_empty() {
110            let attrs = HashMap::from_iter(
111                tag.attrs.iter()
112                    .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
113            );
114            self.push_queue(Packet::StreamStart(attrs));
115        }
116
117        self.stack.push(el);
118    }
119
120    fn handle_end_tag(&mut self) {
121        let el = self.stack.pop().unwrap();
122        self.ns_stack.pop();
123
124        match self.stack.len() {
125            // </stream:stream>
126            0 =>
127                self.push_queue(Packet::StreamEnd),
128            // </stanza>
129            1 =>
130                self.push_queue(Packet::Stanza(el)),
131            len => {
132                let parent = &mut self.stack[len - 1];
133                parent.append_child(el);
134            },
135        }
136    }
137}
138
139impl TokenSink for ParserSink {
140    fn process_token(&mut self, token: Token) {
141        match token {
142            Token::TagToken(tag) => match tag.kind {
143                TagKind::StartTag =>
144                    self.handle_start_tag(tag),
145                TagKind::EndTag =>
146                    self.handle_end_tag(),
147                TagKind::EmptyTag => {
148                    self.handle_start_tag(tag);
149                    self.handle_end_tag();
150                },
151                TagKind::ShortTag =>
152                    self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
153            },
154            Token::CharacterTokens(tendril) =>
155                match self.stack.len() {
156                    0 | 1 =>
157                        self.push_queue(Packet::Text(tendril.into())),
158                    len => {
159                        let el = &mut self.stack[len - 1];
160                        el.append_text_node(tendril);
161                    },
162                },
163            Token::EOFToken =>
164                self.push_queue(Packet::StreamEnd),
165            Token::ParseError(s) => {
166                // println!("ParseError: {:?}", s);
167                self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
168            },
169            _ => (),
170        }
171    }
172
173    // fn end(&mut self) {
174    // }
175}
176
177/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
178pub struct XMPPCodec {
179    /// Outgoing
180    ns: Option<String>,
181    /// Incoming
182    parser: XmlTokenizer<ParserSink>,
183    /// For handling incoming truncated utf8
184    // TODO: optimize using  tendrils?
185    buf: Vec<u8>,
186    /// Shared with ParserSink
187    queue: Rc<RefCell<VecDeque<Packet>>>,
188}
189
190impl XMPPCodec {
191    /// Constructor
192    pub fn new() -> Self {
193        let queue = Rc::new(RefCell::new(VecDeque::new()));
194        let sink = ParserSink::new(queue.clone());
195        // TODO: configure parser?
196        let parser = XmlTokenizer::new(sink, Default::default());
197        XMPPCodec {
198            ns: None,
199            parser,
200            queue,
201            buf: vec![],
202        }
203    }
204}
205
206impl Default for XMPPCodec {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212impl Decoder for XMPPCodec {
213    type Item = Packet;
214    type Error = Error;
215
216    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
217        let buf1: Box<AsRef<[u8]>> =
218            if ! self.buf.is_empty() && ! buf.is_empty() {
219                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
220                prefix.extend_from_slice(buf.take().as_ref());
221                Box::new(prefix)
222            } else {
223                Box::new(buf.take())
224            };
225        let buf1 = buf1.as_ref().as_ref();
226        match from_utf8(buf1) {
227            Ok(s) => {
228                if ! s.is_empty() {
229                    // println!("<< {}", s);
230                    let tendril = FromIterator::from_iter(s.chars());
231                    self.parser.feed(tendril);
232                }
233            },
234            // Remedies for truncated utf8
235            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
236                // Prepare all the valid data
237                let mut b = BytesMut::with_capacity(e.valid_up_to());
238                b.put(&buf1[0..e.valid_up_to()]);
239
240                // Retry
241                let result = self.decode(&mut b);
242
243                // Keep the tail back in
244                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
245
246                return result;
247            },
248            Err(e) => {
249                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
250                return Err(Error::new(ErrorKind::InvalidInput, e));
251            },
252        }
253
254        let result = self.queue.borrow_mut().pop_front();
255        Ok(result)
256    }
257
258    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
259        self.decode(buf)
260    }
261}
262
263impl Encoder for XMPPCodec {
264    type Item = Packet;
265    type Error = Error;
266
267    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
268        let remaining = dst.capacity() - dst.len();
269        let max_stanza_size: usize = 2usize.pow(16);
270        if remaining < max_stanza_size {
271            dst.reserve(max_stanza_size - remaining);
272        }
273
274        match item {
275            Packet::StreamStart(start_attrs) => {
276                let mut buf = String::new();
277                write!(buf, "<stream:stream").unwrap();
278                for (name, value) in start_attrs {
279                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
280                        .unwrap();
281                    if name == "xmlns" {
282                        self.ns = Some(value);
283                    }
284                }
285                write!(buf, ">\n").unwrap();
286
287                print!(">> {}", buf);
288                write!(dst, "{}", buf)
289                    .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
290            },
291            Packet::Stanza(stanza) => {
292                stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
293                    .and_then(|_| {
294                        // println!(">> {:?}", dst);
295                        Ok(())
296                    })
297                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
298            },
299            Packet::Text(text) => {
300                write_text(&text, dst)
301                    .and_then(|_| {
302                        // println!(">> {:?}", dst);
303                        Ok(())
304                    })
305                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
306            },
307            // TODO: Implement all
308            _ => Ok(())
309        }
310    }
311}
312
313/// Write XML-escaped text string
314pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
315    write!(writer, "{}", escape(text))
316}
317
/// Replace the five XML special characters (`&`, `<`, `>`, `'`, `"`) with
/// their entity references; all other characters are copied unchanged.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());

    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            _ => {
                escaped.push(ch);
                continue;
            },
        };
        escaped.push_str(entity);
    }

    escaped
}
334
335/// BytesMut impl only std::fmt::Write but not std::io::Write. The
336/// latter trait is required for minidom's
337/// `Element::write_to_inner()`.
338struct WriteBytes<'a> {
339    dst: &'a mut BytesMut,
340}
341
342impl<'a> WriteBytes<'a> {
343    fn new(dst: &'a mut BytesMut) -> Self {
344        WriteBytes { dst }
345    }
346}
347
348impl<'a> std::io::Write for WriteBytes<'a> {
349    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
350        self.dst.put_slice(buf);
351        Ok(buf.len())
352    }
353
354    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
355        Ok(())
356    }
357}
358
359
360#[cfg(test)]
361mod tests {
362    use super::*;
363    use bytes::BytesMut;
364
365    #[test]
366    fn test_stream_start() {
367        let mut c = XMPPCodec::new();
368        let mut b = BytesMut::with_capacity(1024);
369        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
370        let r = c.decode(&mut b);
371        assert!(match r {
372            Ok(Some(Packet::StreamStart(_))) => true,
373            _ => false,
374        });
375    }
376
377    #[test]
378    fn test_truncated_stanza() {
379        let mut c = XMPPCodec::new();
380        let mut b = BytesMut::with_capacity(1024);
381        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
382        let r = c.decode(&mut b);
383        assert!(match r {
384            Ok(Some(Packet::StreamStart(_))) => true,
385            _ => false,
386        });
387
388        b.clear();
389        b.put(r"<test>ß</test");
390        let r = c.decode(&mut b);
391        assert!(match r {
392            Ok(None) => true,
393            _ => false,
394        });
395
396        b.clear();
397        b.put(r">");
398        let r = c.decode(&mut b);
399        assert!(match r {
400            Ok(Some(Packet::Stanza(ref el)))
401                if el.name() == "test"
402                && el.text() == "ß"
403                => true,
404            _ => false,
405        });
406    }
407
408    #[test]
409    fn test_truncated_utf8() {
410        let mut c = XMPPCodec::new();
411        let mut b = BytesMut::with_capacity(1024);
412        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
413        let r = c.decode(&mut b);
414        assert!(match r {
415            Ok(Some(Packet::StreamStart(_))) => true,
416            _ => false,
417        });
418
419        b.clear();
420        b.put(&b"<test>\xc3"[..]);
421        let r = c.decode(&mut b);
422        assert!(match r {
423            Ok(None) => true,
424            _ => false,
425        });
426
427        b.clear();
428        b.put(&b"\x9f</test>"[..]);
429        let r = c.decode(&mut b);
430        assert!(match r {
431            Ok(Some(Packet::Stanza(ref el)))
432                if el.name() == "test"
433                && el.text() == "ß"
434                => true,
435            _ => false,
436        });
437    }
438
439    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
440    #[test]
441    fn test_large_stanza() {
442        use std::io::Cursor;
443        use futures::{Future, Sink};
444        use tokio_codec::FramedWrite;
445        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
446        let mut text = "".to_owned();
447        for _ in 0..2usize.pow(15) {
448            text = text + "A";
449        }
450        let stanza = Element::builder("message")
451            .append(
452                Element::builder("body")
453                    .append(&text)
454                    .build()
455            )
456            .build();
457        let framed = framed.send(Packet::Stanza(stanza))
458            .wait()
459            .expect("send");
460        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
461    }
462}