xmpp_codec.rs

  1use std;
  2use std::default::Default;
  3use std::iter::FromIterator;
  4use std::cell::RefCell;
  5use std::rc::Rc;
  6use std::fmt::Write;
  7use std::str::from_utf8;
  8use std::io::{Error, ErrorKind};
  9use std::collections::HashMap;
 10use std::collections::vec_deque::VecDeque;
 11use tokio_io::codec::{Encoder, Decoder};
 12use minidom::{Element, Node};
 13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 14use xml5ever::interface::Attribute;
 15use bytes::{BytesMut, BufMut};
 16
 17// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
 18
 19#[derive(Debug)]
 20pub enum Packet {
 21    Error(Box<std::error::Error>),
 22    StreamStart(HashMap<String, String>),
 23    Stanza(Element),
 24    Text(String),
 25    StreamEnd,
 26}
 27
 28struct ParserSink {
 29    // Ready stanzas, shared with XMPPCodec
 30    queue: Rc<RefCell<VecDeque<Packet>>>,
 31    // Parsing stack
 32    stack: Vec<Element>,
 33    ns_stack: Vec<HashMap<Option<String>, String>>,
 34}
 35
 36impl ParserSink {
 37    pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
 38        ParserSink {
 39            queue,
 40            stack: vec![],
 41            ns_stack: vec![],
 42        }
 43    }
 44
 45    fn push_queue(&self, pkt: Packet) {
 46        self.queue.borrow_mut().push_back(pkt);
 47    }
 48
 49    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 50        for nss in self.ns_stack.iter().rev() {
 51            match nss.get(prefix) {
 52                Some(ns) => return Some(ns),
 53                None => (),
 54            }
 55        }
 56
 57        None
 58    }
 59
 60    fn handle_start_tag(&mut self, tag: Tag) {
 61        let mut nss = HashMap::new();
 62        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
 63            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 64            .unwrap_or(false);
 65        for attr in &tag.attrs {
 66            match attr.name.local.as_ref() {
 67                "xmlns" => {
 68                    nss.insert(None, attr.value.as_ref().to_owned());
 69                },
 70                prefix if is_prefix_xmlns(attr) => {
 71                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 72                    },
 73                _ => (),
 74            }
 75        }
 76        self.ns_stack.push(nss);
 77
 78        let el = {
 79            let mut el_builder = Element::builder(tag.name.local.as_ref());
 80            match self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned())) {
 81                Some(el_ns) => el_builder = el_builder.ns(el_ns),
 82                None => (),
 83            }
 84            for attr in &tag.attrs {
 85                match attr.name.local.as_ref() {
 86                    "xmlns" => (),
 87                    _ if is_prefix_xmlns(attr) => (),
 88                    _ => {
 89                        el_builder = el_builder.attr(
 90                            attr.name.local.as_ref(),
 91                            attr.value.as_ref()
 92                        );
 93                    },
 94                }
 95            }
 96            el_builder.build()
 97        };
 98
 99        if self.stack.is_empty() {
100            let attrs = HashMap::from_iter(
101                tag.attrs.iter()
102                    .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
103            );
104            self.push_queue(Packet::StreamStart(attrs));
105        }
106
107        self.stack.push(el);
108    }
109
110    fn handle_end_tag(&mut self) {
111        let el = self.stack.pop().unwrap();
112        self.ns_stack.pop();
113
114        match self.stack.len() {
115            // </stream:stream>
116            0 =>
117                self.push_queue(Packet::StreamEnd),
118            // </stanza>
119            1 =>
120                self.push_queue(Packet::Stanza(el)),
121            len => {
122                let parent = &mut self.stack[len - 1];
123                parent.append_child(el);
124            },
125        }
126    }
127}
128
129impl TokenSink for ParserSink {
130    fn process_token(&mut self, token: Token) {
131        match token {
132            Token::TagToken(tag) => match tag.kind {
133                TagKind::StartTag =>
134                    self.handle_start_tag(tag),
135                TagKind::EndTag =>
136                    self.handle_end_tag(),
137                TagKind::EmptyTag => {
138                    self.handle_start_tag(tag);
139                    self.handle_end_tag();
140                },
141                TagKind::ShortTag =>
142                    self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
143            },
144            Token::CharacterTokens(tendril) =>
145                match self.stack.len() {
146                    0 | 1 =>
147                        self.push_queue(Packet::Text(tendril.into())),
148                    len => {
149                        let el = &mut self.stack[len - 1];
150                        el.append_text_node(tendril);
151                    },
152                },
153            Token::EOFToken =>
154                self.push_queue(Packet::StreamEnd),
155            Token::ParseError(s) => {
156                println!("ParseError: {:?}", s);
157                self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
158            },
159            _ => (),
160        }
161    }
162
163    // fn end(&mut self) {
164    // }
165}
166
167pub struct XMPPCodec {
168    /// Outgoing
169    ns: Option<String>,
170    /// Incoming
171    parser: XmlTokenizer<ParserSink>,
172    /// For handling incoming truncated utf8
173    // TODO: optimize using  tendrils?
174    buf: Vec<u8>,
175    /// Shared with ParserSink
176    queue: Rc<RefCell<VecDeque<Packet>>>,
177}
178
179impl XMPPCodec {
180    pub fn new() -> Self {
181        let queue = Rc::new(RefCell::new((VecDeque::new())));
182        let sink = ParserSink::new(queue.clone());
183        // TODO: configure parser?
184        let parser = XmlTokenizer::new(sink, Default::default());
185        XMPPCodec {
186            ns: None,
187            parser,
188            queue,
189            buf: vec![],
190        }
191    }
192}
193
194impl Decoder for XMPPCodec {
195    type Item = Packet;
196    type Error = Error;
197
198    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
199        let buf1: Box<AsRef<[u8]>> =
200            if self.buf.len() > 0 && buf.len() > 0 {
201                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
202                prefix.extend_from_slice(buf.take().as_ref());
203                Box::new(prefix)
204            } else {
205                Box::new(buf.take())
206            };
207        let buf1 = buf1.as_ref().as_ref();
208        match from_utf8(buf1) {
209            Ok(s) => {
210                if s.len() > 0 {
211                    println!("<< {}", s);
212                    let tendril = FromIterator::from_iter(s.chars());
213                    self.parser.feed(tendril);
214                }
215            },
216            // Remedies for truncated utf8
217            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
218                // Prepare all the valid data
219                let mut b = BytesMut::with_capacity(e.valid_up_to());
220                b.put(&buf1[0..e.valid_up_to()]);
221
222                // Retry
223                let result = self.decode(&mut b);
224
225                // Keep the tail back in
226                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
227
228                return result;
229            },
230            Err(e) => {
231                println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
232                return Err(Error::new(ErrorKind::InvalidInput, e));
233            },
234        }
235
236        let result = self.queue.borrow_mut().pop_front();
237        Ok(result)
238    }
239
240    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
241        self.decode(buf)
242    }
243}
244
245impl Encoder for XMPPCodec {
246    type Item = Packet;
247    type Error = Error;
248
249    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
250        let remaining = dst.capacity() - dst.len();
251        let max_stanza_size: usize = 2usize.pow(16);
252        if remaining < max_stanza_size {
253            dst.reserve(max_stanza_size - remaining);
254        }
255
256        match item {
257            Packet::StreamStart(start_attrs) => {
258                let mut buf = String::new();
259                write!(buf, "<stream:stream").unwrap();
260                for (name, value) in start_attrs.into_iter() {
261                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
262                        .unwrap();
263                    if name == "xmlns" {
264                        self.ns = Some(value);
265                    }
266                }
267                write!(buf, ">\n").unwrap();
268
269                print!(">> {}", buf);
270                write!(dst, "{}", buf)
271                    .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
272            },
273            Packet::Stanza(stanza) => {
274                let root_ns = self.ns.as_ref().map(|s| s.as_ref());
275                write_element(&stanza, dst, root_ns)
276                    .and_then(|_| {
277                        println!(">> {:?}", dst);
278                        Ok(())
279                    })
280                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
281            },
282            Packet::Text(text) => {
283                write_text(&text, dst)
284                    .and_then(|_| {
285                        println!(">> {:?}", dst);
286                        Ok(())
287                    })
288                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
289            },
290            // TODO: Implement all
291            _ => Ok(())
292        }
293    }
294}
295
/// Writes `text` to `writer` exactly as given; no XML escaping is applied.
///
/// # Errors
///
/// Propagates any error reported by `writer`.
pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
    writer.write_str(text)
}
299
300// TODO: escape everything?
301pub fn write_element<W: Write>(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> {
302    write!(writer, "<")?;
303    write!(writer, "{}", el.name())?;
304
305    if let Some(ref ns) = el.ns() {
306        if parent_ns.map(|s| s.as_ref()) != el.ns() {
307            write!(writer, " xmlns=\"{}\"", ns)?;
308        }
309    }
310
311    for (key, value) in el.attrs() {
312        write!(writer, " {}=\"{}\"", key, value)?;
313    }
314
315    if ! el.nodes().any(|_| true) {
316        write!(writer, " />")?;
317        return Ok(())
318    }
319
320    write!(writer, ">")?;
321
322    for node in el.nodes() {
323        match node {
324            &Node::Element(ref child) =>
325                write_element(child, writer, el.ns())?,
326            &Node::Text(ref text) =>
327                write_text(text, writer)?,
328        }
329    }
330
331    write!(writer, "</{}>", el.name())?;
332    Ok(())
333}
334
/// Returns a copy of `input` in which `&`, `<`, `>`, `'` and `"` are
/// replaced by their predefined XML entities; every other character is
/// kept as is.
///
/// Copied from RustyXML for now
pub fn escape(input: &str) -> String {
    let mut escaped = String::with_capacity(input.len());

    for ch in input.chars() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            other => {
                escaped.push(other);
                continue;
            },
        };
        escaped.push_str(entity);
    }

    escaped
}
351
352#[cfg(test)]
353mod tests {
354    use super::*;
355    use bytes::BytesMut;
356
357    #[test]
358    fn test_stream_start() {
359        let mut c = XMPPCodec::new();
360        let mut b = BytesMut::with_capacity(1024);
361        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
362        let r = c.decode(&mut b);
363        assert!(match r {
364            Ok(Some(Packet::StreamStart(_))) => true,
365            _ => false,
366        });
367    }
368
369    #[test]
370    fn test_truncated_stanza() {
371        let mut c = XMPPCodec::new();
372        let mut b = BytesMut::with_capacity(1024);
373        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
374        let r = c.decode(&mut b);
375        assert!(match r {
376            Ok(Some(Packet::StreamStart(_))) => true,
377            _ => false,
378        });
379
380        b.clear();
381        b.put(r"<test>ß</test");
382        let r = c.decode(&mut b);
383        assert!(match r {
384            Ok(None) => true,
385            _ => false,
386        });
387
388        b.clear();
389        b.put(r">");
390        let r = c.decode(&mut b);
391        assert!(match r {
392            Ok(Some(Packet::Stanza(ref el)))
393                if el.name() == "test"
394                && el.text() == "ß"
395                => true,
396            _ => false,
397        });
398    }
399
400    #[test]
401    fn test_truncated_utf8() {
402        let mut c = XMPPCodec::new();
403        let mut b = BytesMut::with_capacity(1024);
404        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
405        let r = c.decode(&mut b);
406        assert!(match r {
407            Ok(Some(Packet::StreamStart(_))) => true,
408            _ => false,
409        });
410
411        b.clear();
412        b.put(&b"<test>\xc3"[..]);
413        let r = c.decode(&mut b);
414        assert!(match r {
415            Ok(None) => true,
416            _ => false,
417        });
418
419        b.clear();
420        b.put(&b"\x9f</test>"[..]);
421        let r = c.decode(&mut b);
422        assert!(match r {
423            Ok(Some(Packet::Stanza(ref el)))
424                if el.name() == "test"
425                && el.text() == "ß"
426                => true,
427            _ => false,
428        });
429    }
430
431    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
432    #[test]
433    fn test_large_stanza() {
434        use std::io::Cursor;
435        use futures::{Future, Sink};
436        use tokio_io::codec::FramedWrite;
437        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
438        let mut text = "".to_owned();
439        for _ in 0..2usize.pow(15) {
440            text = text + "A";
441        }
442        let stanza = Element::builder("message")
443            .append(
444                Element::builder("body")
445                    .append(&text)
446                    .build()
447            )
448            .build();
449        let framed = framed.send(Packet::Stanza(stanza))
450            .wait()
451            .expect("send");
452        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
453    }
454}