xmpp_codec.rs

  1use std;
  2use std::default::Default;
  3use std::iter::FromIterator;
  4use std::cell::RefCell;
  5use std::rc::Rc;
  6use std::fmt::Write;
  7use std::str::from_utf8;
  8use std::io::{Error, ErrorKind};
  9use std::collections::HashMap;
 10use std::collections::vec_deque::VecDeque;
 11use tokio_io::codec::{Encoder, Decoder};
 12use minidom::Element;
 13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 14use xml5ever::interface::Attribute;
 15use bytes::{BytesMut, BufMut};
 16use quick_xml::Writer as EventWriter;
 17
 18// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
 19
 20#[derive(Debug)]
 21pub enum Packet {
 22    Error(Box<std::error::Error>),
 23    StreamStart(HashMap<String, String>),
 24    Stanza(Element),
 25    Text(String),
 26    StreamEnd,
 27}
 28
 29struct ParserSink {
 30    // Ready stanzas, shared with XMPPCodec
 31    queue: Rc<RefCell<VecDeque<Packet>>>,
 32    // Parsing stack
 33    stack: Vec<Element>,
 34    ns_stack: Vec<HashMap<Option<String>, String>>,
 35}
 36
 37impl ParserSink {
 38    pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
 39        ParserSink {
 40            queue,
 41            stack: vec![],
 42            ns_stack: vec![],
 43        }
 44    }
 45
 46    fn push_queue(&self, pkt: Packet) {
 47        self.queue.borrow_mut().push_back(pkt);
 48    }
 49
 50    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 51        for nss in self.ns_stack.iter().rev() {
 52            if let Some(ns) = nss.get(prefix) {
 53                return Some(ns);
 54            }
 55        }
 56
 57        None
 58    }
 59
 60    fn handle_start_tag(&mut self, tag: Tag) {
 61        let mut nss = HashMap::new();
 62        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
 63            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 64            .unwrap_or(false);
 65        for attr in &tag.attrs {
 66            match attr.name.local.as_ref() {
 67                "xmlns" => {
 68                    nss.insert(None, attr.value.as_ref().to_owned());
 69                },
 70                prefix if is_prefix_xmlns(attr) => {
 71                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 72                    },
 73                _ => (),
 74            }
 75        }
 76        self.ns_stack.push(nss);
 77
 78        let el = {
 79            let mut el_builder = Element::builder(tag.name.local.as_ref());
 80            if let Some(el_ns) = self.lookup_ns(
 81                &tag.name.prefix.map(|prefix|
 82                                     prefix.as_ref().to_owned())
 83            ) {
 84                el_builder = el_builder.ns(el_ns);
 85            }
 86            for attr in &tag.attrs {
 87                match attr.name.local.as_ref() {
 88                    "xmlns" => (),
 89                    _ if is_prefix_xmlns(attr) => (),
 90                    _ => {
 91                        el_builder = el_builder.attr(
 92                            attr.name.local.as_ref(),
 93                            attr.value.as_ref()
 94                        );
 95                    },
 96                }
 97            }
 98            el_builder.build()
 99        };
100
101        if self.stack.is_empty() {
102            let attrs = HashMap::from_iter(
103                tag.attrs.iter()
104                    .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
105            );
106            self.push_queue(Packet::StreamStart(attrs));
107        }
108
109        self.stack.push(el);
110    }
111
112    fn handle_end_tag(&mut self) {
113        let el = self.stack.pop().unwrap();
114        self.ns_stack.pop();
115
116        match self.stack.len() {
117            // </stream:stream>
118            0 =>
119                self.push_queue(Packet::StreamEnd),
120            // </stanza>
121            1 =>
122                self.push_queue(Packet::Stanza(el)),
123            len => {
124                let parent = &mut self.stack[len - 1];
125                parent.append_child(el);
126            },
127        }
128    }
129}
130
131impl TokenSink for ParserSink {
132    fn process_token(&mut self, token: Token) {
133        match token {
134            Token::TagToken(tag) => match tag.kind {
135                TagKind::StartTag =>
136                    self.handle_start_tag(tag),
137                TagKind::EndTag =>
138                    self.handle_end_tag(),
139                TagKind::EmptyTag => {
140                    self.handle_start_tag(tag);
141                    self.handle_end_tag();
142                },
143                TagKind::ShortTag =>
144                    self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
145            },
146            Token::CharacterTokens(tendril) =>
147                match self.stack.len() {
148                    0 | 1 =>
149                        self.push_queue(Packet::Text(tendril.into())),
150                    len => {
151                        let el = &mut self.stack[len - 1];
152                        el.append_text_node(tendril);
153                    },
154                },
155            Token::EOFToken =>
156                self.push_queue(Packet::StreamEnd),
157            Token::ParseError(s) => {
158                println!("ParseError: {:?}", s);
159                self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
160            },
161            _ => (),
162        }
163    }
164
165    // fn end(&mut self) {
166    // }
167}
168
169pub struct XMPPCodec {
170    /// Outgoing
171    ns: Option<String>,
172    /// Incoming
173    parser: XmlTokenizer<ParserSink>,
174    /// For handling incoming truncated utf8
175    // TODO: optimize using  tendrils?
176    buf: Vec<u8>,
177    /// Shared with ParserSink
178    queue: Rc<RefCell<VecDeque<Packet>>>,
179}
180
181impl XMPPCodec {
182    pub fn new() -> Self {
183        let queue = Rc::new(RefCell::new(VecDeque::new()));
184        let sink = ParserSink::new(queue.clone());
185        // TODO: configure parser?
186        let parser = XmlTokenizer::new(sink, Default::default());
187        XMPPCodec {
188            ns: None,
189            parser,
190            queue,
191            buf: vec![],
192        }
193    }
194}
195
196impl Default for XMPPCodec {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
202impl Decoder for XMPPCodec {
203    type Item = Packet;
204    type Error = Error;
205
206    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
207        let buf1: Box<AsRef<[u8]>> =
208            if ! self.buf.is_empty() && ! buf.is_empty() {
209                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
210                prefix.extend_from_slice(buf.take().as_ref());
211                Box::new(prefix)
212            } else {
213                Box::new(buf.take())
214            };
215        let buf1 = buf1.as_ref().as_ref();
216        match from_utf8(buf1) {
217            Ok(s) => {
218                if ! s.is_empty() {
219                    println!("<< {}", s);
220                    let tendril = FromIterator::from_iter(s.chars());
221                    self.parser.feed(tendril);
222                }
223            },
224            // Remedies for truncated utf8
225            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
226                // Prepare all the valid data
227                let mut b = BytesMut::with_capacity(e.valid_up_to());
228                b.put(&buf1[0..e.valid_up_to()]);
229
230                // Retry
231                let result = self.decode(&mut b);
232
233                // Keep the tail back in
234                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
235
236                return result;
237            },
238            Err(e) => {
239                println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
240                return Err(Error::new(ErrorKind::InvalidInput, e));
241            },
242        }
243
244        let result = self.queue.borrow_mut().pop_front();
245        Ok(result)
246    }
247
248    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
249        self.decode(buf)
250    }
251}
252
253impl Encoder for XMPPCodec {
254    type Item = Packet;
255    type Error = Error;
256
257    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
258        let remaining = dst.capacity() - dst.len();
259        let max_stanza_size: usize = 2usize.pow(16);
260        if remaining < max_stanza_size {
261            dst.reserve(max_stanza_size - remaining);
262        }
263
264        match item {
265            Packet::StreamStart(start_attrs) => {
266                let mut buf = String::new();
267                write!(buf, "<stream:stream").unwrap();
268                for (name, value) in start_attrs {
269                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
270                        .unwrap();
271                    if name == "xmlns" {
272                        self.ns = Some(value);
273                    }
274                }
275                write!(buf, ">\n").unwrap();
276
277                print!(">> {}", buf);
278                write!(dst, "{}", buf)
279                    .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
280            },
281            Packet::Stanza(stanza) => {
282                stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
283                    .and_then(|_| {
284                        println!(">> {:?}", dst);
285                        Ok(())
286                    })
287                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
288            },
289            Packet::Text(text) => {
290                write_text(&text, dst)
291                    .and_then(|_| {
292                        println!(">> {:?}", dst);
293                        Ok(())
294                    })
295                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
296            },
297            // TODO: Implement all
298            _ => Ok(())
299        }
300    }
301}
302
303pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
304    write!(writer, "{}", escape(text))
305}
306
/// Returns a copy of `input` in which `&`, `<`, `>`, `'` and `"` are
/// replaced by the corresponding XML entities.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    input.chars().fold(String::with_capacity(input.len()), |mut escaped, ch| {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '\'' => escaped.push_str("&apos;"),
            '"' => escaped.push_str("&quot;"),
            other => escaped.push(other),
        }
        escaped
    })
}
323
324/// BytesMut impl only std::fmt::Write but not std::io::Write. The
325/// latter trait is required for minidom's
326/// `Element::write_to_inner()`.
327struct WriteBytes<'a> {
328    dst: &'a mut BytesMut,
329}
330
331impl<'a> WriteBytes<'a> {
332    fn new(dst: &'a mut BytesMut) -> Self {
333        WriteBytes { dst }
334    }
335}
336
337impl<'a> std::io::Write for WriteBytes<'a> {
338    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
339        self.dst.put_slice(buf);
340        Ok(buf.len())
341    }
342
343    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
344        Ok(())
345    }
346}
347
348
349#[cfg(test)]
350mod tests {
351    use super::*;
352    use bytes::BytesMut;
353
354    #[test]
355    fn test_stream_start() {
356        let mut c = XMPPCodec::new();
357        let mut b = BytesMut::with_capacity(1024);
358        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
359        let r = c.decode(&mut b);
360        assert!(match r {
361            Ok(Some(Packet::StreamStart(_))) => true,
362            _ => false,
363        });
364    }
365
366    #[test]
367    fn test_truncated_stanza() {
368        let mut c = XMPPCodec::new();
369        let mut b = BytesMut::with_capacity(1024);
370        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
371        let r = c.decode(&mut b);
372        assert!(match r {
373            Ok(Some(Packet::StreamStart(_))) => true,
374            _ => false,
375        });
376
377        b.clear();
378        b.put(r"<test>ß</test");
379        let r = c.decode(&mut b);
380        assert!(match r {
381            Ok(None) => true,
382            _ => false,
383        });
384
385        b.clear();
386        b.put(r">");
387        let r = c.decode(&mut b);
388        assert!(match r {
389            Ok(Some(Packet::Stanza(ref el)))
390                if el.name() == "test"
391                && el.text() == "ß"
392                => true,
393            _ => false,
394        });
395    }
396
397    #[test]
398    fn test_truncated_utf8() {
399        let mut c = XMPPCodec::new();
400        let mut b = BytesMut::with_capacity(1024);
401        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
402        let r = c.decode(&mut b);
403        assert!(match r {
404            Ok(Some(Packet::StreamStart(_))) => true,
405            _ => false,
406        });
407
408        b.clear();
409        b.put(&b"<test>\xc3"[..]);
410        let r = c.decode(&mut b);
411        assert!(match r {
412            Ok(None) => true,
413            _ => false,
414        });
415
416        b.clear();
417        b.put(&b"\x9f</test>"[..]);
418        let r = c.decode(&mut b);
419        assert!(match r {
420            Ok(Some(Packet::Stanza(ref el)))
421                if el.name() == "test"
422                && el.text() == "ß"
423                => true,
424            _ => false,
425        });
426    }
427
428    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
429    #[test]
430    fn test_large_stanza() {
431        use std::io::Cursor;
432        use futures::{Future, Sink};
433        use tokio_io::codec::FramedWrite;
434        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
435        let mut text = "".to_owned();
436        for _ in 0..2usize.pow(15) {
437            text = text + "A";
438        }
439        let stanza = Element::builder("message")
440            .append(
441                Element::builder("body")
442                    .append(&text)
443                    .build()
444            )
445            .build();
446        let framed = framed.send(Packet::Stanza(stanza))
447            .wait()
448            .expect("send");
449        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
450    }
451}