xmpp_codec.rs

//! XML stream codec for XMPP: decodes incoming bytes into `Packet`s and encodes outgoing ones.
  2
  3use std;
  4use std::default::Default;
  5use std::iter::FromIterator;
  6use std::cell::RefCell;
  7use std::rc::Rc;
  8use std::fmt::Write;
  9use std::str::from_utf8;
 10use std::io;
 11use std::collections::HashMap;
 12use std::collections::vec_deque::VecDeque;
 13use tokio_codec::{Encoder, Decoder};
 14use minidom::Element;
 15use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 16use xml5ever::interface::Attribute;
 17use bytes::{BytesMut, BufMut};
 18use quick_xml::Writer as EventWriter;
 19use crate::{ParserError, ParseError};
 20
 21/// Anything that can be sent or received on an XMPP/XML stream
 22#[derive(Debug)]
 23pub enum Packet {
 24    /// `<stream:stream>` start tag
 25    StreamStart(HashMap<String, String>),
 26    /// A complete stanza or nonza
 27    Stanza(Element),
 28    /// Plain text (think whitespace keep-alive)
 29    Text(String),
 30    /// `</stream:stream>` closing tag
 31    StreamEnd,
 32}
 33
 34type QueueItem = Result<Packet, ParserError>;
 35
 36/// Parser state
 37struct ParserSink {
 38    // Ready stanzas, shared with XMPPCodec
 39    queue: Rc<RefCell<VecDeque<QueueItem>>>,
 40    // Parsing stack
 41    stack: Vec<Element>,
 42    ns_stack: Vec<HashMap<Option<String>, String>>,
 43}
 44
 45impl ParserSink {
 46    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
 47        ParserSink {
 48            queue,
 49            stack: vec![],
 50            ns_stack: vec![],
 51        }
 52    }
 53
 54    fn push_queue(&self, pkt: Packet) {
 55        self.queue.borrow_mut().push_back(Ok(pkt));
 56    }
 57
 58    fn push_queue_error(&self, e: ParserError) {
 59        self.queue.borrow_mut().push_back(Err(e));
 60    }
 61
 62    /// Lookup XML namespace declaration for given prefix (or no prefix)
 63    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 64        for nss in self.ns_stack.iter().rev() {
 65            if let Some(ns) = nss.get(prefix) {
 66                return Some(ns);
 67            }
 68        }
 69
 70        None
 71    }
 72
 73    fn handle_start_tag(&mut self, tag: Tag) {
 74        let mut nss = HashMap::new();
 75        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
 76            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 77            .unwrap_or(false);
 78        for attr in &tag.attrs {
 79            match attr.name.local.as_ref() {
 80                "xmlns" => {
 81                    nss.insert(None, attr.value.as_ref().to_owned());
 82                },
 83                prefix if is_prefix_xmlns(attr) => {
 84                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 85                    },
 86                _ => (),
 87            }
 88        }
 89        self.ns_stack.push(nss);
 90
 91        let el = {
 92            let mut el_builder = Element::builder(tag.name.local.as_ref());
 93            if let Some(el_ns) = self.lookup_ns(
 94                &tag.name.prefix.map(|prefix|
 95                                     prefix.as_ref().to_owned())
 96            ) {
 97                el_builder = el_builder.ns(el_ns);
 98            }
 99            for attr in &tag.attrs {
100                match attr.name.local.as_ref() {
101                    "xmlns" => (),
102                    _ if is_prefix_xmlns(attr) => (),
103                    _ => {
104                        el_builder = el_builder.attr(
105                            attr.name.local.as_ref(),
106                            attr.value.as_ref()
107                        );
108                    },
109                }
110            }
111            el_builder.build()
112        };
113
114        if self.stack.is_empty() {
115            let attrs = HashMap::from_iter(
116                tag.attrs.iter()
117                    .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
118            );
119            self.push_queue(Packet::StreamStart(attrs));
120        }
121
122        self.stack.push(el);
123    }
124
125    fn handle_end_tag(&mut self) {
126        let el = self.stack.pop().unwrap();
127        self.ns_stack.pop();
128
129        match self.stack.len() {
130            // </stream:stream>
131            0 =>
132                self.push_queue(Packet::StreamEnd),
133            // </stanza>
134            1 =>
135                self.push_queue(Packet::Stanza(el)),
136            len => {
137                let parent = &mut self.stack[len - 1];
138                parent.append_child(el);
139            },
140        }
141    }
142}
143
144impl TokenSink for ParserSink {
145    fn process_token(&mut self, token: Token) {
146        match token {
147            Token::TagToken(tag) => match tag.kind {
148                TagKind::StartTag =>
149                    self.handle_start_tag(tag),
150                TagKind::EndTag =>
151                    self.handle_end_tag(),
152                TagKind::EmptyTag => {
153                    self.handle_start_tag(tag);
154                    self.handle_end_tag();
155                },
156                TagKind::ShortTag =>
157                    self.push_queue_error(ParserError::ShortTag),
158            },
159            Token::CharacterTokens(tendril) =>
160                match self.stack.len() {
161                    0 | 1 =>
162                        self.push_queue(Packet::Text(tendril.into())),
163                    len => {
164                        let el = &mut self.stack[len - 1];
165                        el.append_text_node(tendril);
166                    },
167                },
168            Token::EOFToken =>
169                self.push_queue(Packet::StreamEnd),
170            Token::ParseError(s) => {
171                // println!("ParseError: {:?}", s);
172                self.push_queue_error(ParserError::Parse(ParseError(s)));
173            },
174            _ => (),
175        }
176    }
177
178    // fn end(&mut self) {
179    // }
180}
181
182/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
183pub struct XMPPCodec {
184    /// Outgoing
185    ns: Option<String>,
186    /// Incoming
187    parser: XmlTokenizer<ParserSink>,
188    /// For handling incoming truncated utf8
189    // TODO: optimize using  tendrils?
190    buf: Vec<u8>,
191    /// Shared with ParserSink
192    queue: Rc<RefCell<VecDeque<QueueItem>>>,
193}
194
195impl XMPPCodec {
196    /// Constructor
197    pub fn new() -> Self {
198        let queue = Rc::new(RefCell::new(VecDeque::new()));
199        let sink = ParserSink::new(queue.clone());
200        // TODO: configure parser?
201        let parser = XmlTokenizer::new(sink, Default::default());
202        XMPPCodec {
203            ns: None,
204            parser,
205            queue,
206            buf: vec![],
207        }
208    }
209}
210
211impl Default for XMPPCodec {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl Decoder for XMPPCodec {
218    type Item = Packet;
219    type Error = ParserError;
220
221    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
222        let buf1: Box<AsRef<[u8]>> =
223            if ! self.buf.is_empty() && ! buf.is_empty() {
224                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
225                prefix.extend_from_slice(buf.take().as_ref());
226                Box::new(prefix)
227            } else {
228                Box::new(buf.take())
229            };
230        let buf1 = buf1.as_ref().as_ref();
231        match from_utf8(buf1) {
232            Ok(s) => {
233                if ! s.is_empty() {
234                    // println!("<< {}", s);
235                    let tendril = FromIterator::from_iter(s.chars());
236                    self.parser.feed(tendril);
237                }
238            },
239            // Remedies for truncated utf8
240            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
241                // Prepare all the valid data
242                let mut b = BytesMut::with_capacity(e.valid_up_to());
243                b.put(&buf1[0..e.valid_up_to()]);
244
245                // Retry
246                let result = self.decode(&mut b);
247
248                // Keep the tail back in
249                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
250
251                return result;
252            },
253            Err(e) => {
254                // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
255                return Err(ParserError::Utf8(e));
256            },
257        }
258
259        match self.queue.borrow_mut().pop_front() {
260            None => Ok(None),
261            Some(result) =>
262                result.map(|pkt| Some(pkt)),
263        }
264    }
265
266    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
267        self.decode(buf)
268    }
269}
270
271impl Encoder for XMPPCodec {
272    type Item = Packet;
273    type Error = io::Error;
274
275    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
276        let remaining = dst.capacity() - dst.len();
277        let max_stanza_size: usize = 2usize.pow(16);
278        if remaining < max_stanza_size {
279            dst.reserve(max_stanza_size - remaining);
280        }
281
282        match item {
283            Packet::StreamStart(start_attrs) => {
284                let mut buf = String::new();
285                write!(buf, "<stream:stream").unwrap();
286                for (name, value) in start_attrs {
287                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
288                        .unwrap();
289                    if name == "xmlns" {
290                        self.ns = Some(value);
291                    }
292                }
293                write!(buf, ">\n").unwrap();
294
295                print!(">> {}", buf);
296                write!(dst, "{}", buf)
297                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
298            },
299            Packet::Stanza(stanza) => {
300                stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
301                    .and_then(|_| {
302                        // println!(">> {:?}", dst);
303                        Ok(())
304                    })
305                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
306            },
307            Packet::Text(text) => {
308                write_text(&text, dst)
309                    .and_then(|_| {
310                        // println!(">> {:?}", dst);
311                        Ok(())
312                    })
313                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
314            },
315            // TODO: Implement all
316            _ => Ok(())
317        }
318    }
319}
320
321/// Write XML-escaped text string
322pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
323    write!(writer, "{}", escape(text))
324}
325
/// Replaces `&`, `<`, `>`, `'` and `"` with their predefined XML entities;
/// every other character is copied unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());
    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            _ => {
                escaped.push(ch);
                continue;
            }
        };
        escaped.push_str(entity);
    }
    escaped
}
342
343/// BytesMut impl only std::fmt::Write but not std::io::Write. The
344/// latter trait is required for minidom's
345/// `Element::write_to_inner()`.
346struct WriteBytes<'a> {
347    dst: &'a mut BytesMut,
348}
349
350impl<'a> WriteBytes<'a> {
351    fn new(dst: &'a mut BytesMut) -> Self {
352        WriteBytes { dst }
353    }
354}
355
356impl<'a> std::io::Write for WriteBytes<'a> {
357    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
358        self.dst.put_slice(buf);
359        Ok(buf.len())
360    }
361
362    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
363        Ok(())
364    }
365}
366
367
368#[cfg(test)]
369mod tests {
370    use super::*;
371    use bytes::BytesMut;
372
373    #[test]
374    fn test_stream_start() {
375        let mut c = XMPPCodec::new();
376        let mut b = BytesMut::with_capacity(1024);
377        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
378        let r = c.decode(&mut b);
379        assert!(match r {
380            Ok(Some(Packet::StreamStart(_))) => true,
381            _ => false,
382        });
383    }
384
385    #[test]
386    fn test_truncated_stanza() {
387        let mut c = XMPPCodec::new();
388        let mut b = BytesMut::with_capacity(1024);
389        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
390        let r = c.decode(&mut b);
391        assert!(match r {
392            Ok(Some(Packet::StreamStart(_))) => true,
393            _ => false,
394        });
395
396        b.clear();
397        b.put(r"<test>ß</test");
398        let r = c.decode(&mut b);
399        assert!(match r {
400            Ok(None) => true,
401            _ => false,
402        });
403
404        b.clear();
405        b.put(r">");
406        let r = c.decode(&mut b);
407        assert!(match r {
408            Ok(Some(Packet::Stanza(ref el)))
409                if el.name() == "test"
410                && el.text() == "ß"
411                => true,
412            _ => false,
413        });
414    }
415
416    #[test]
417    fn test_truncated_utf8() {
418        let mut c = XMPPCodec::new();
419        let mut b = BytesMut::with_capacity(1024);
420        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
421        let r = c.decode(&mut b);
422        assert!(match r {
423            Ok(Some(Packet::StreamStart(_))) => true,
424            _ => false,
425        });
426
427        b.clear();
428        b.put(&b"<test>\xc3"[..]);
429        let r = c.decode(&mut b);
430        assert!(match r {
431            Ok(None) => true,
432            _ => false,
433        });
434
435        b.clear();
436        b.put(&b"\x9f</test>"[..]);
437        let r = c.decode(&mut b);
438        assert!(match r {
439            Ok(Some(Packet::Stanza(ref el)))
440                if el.name() == "test"
441                && el.text() == "ß"
442                => true,
443            _ => false,
444        });
445    }
446
447    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
448    #[test]
449    fn test_large_stanza() {
450        use std::io::Cursor;
451        use futures::{Future, Sink};
452        use tokio_codec::FramedWrite;
453        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
454        let mut text = "".to_owned();
455        for _ in 0..2usize.pow(15) {
456            text = text + "A";
457        }
458        let stanza = Element::builder("message")
459            .append(
460                Element::builder("body")
461                    .append(&text)
462                    .build()
463            )
464            .build();
465        let framed = framed.send(Packet::Stanza(stanza))
466            .wait()
467            .expect("send");
468        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
469    }
470}