xmpp_codec.rs

  1use std;
  2use std::default::Default;
  3use std::iter::FromIterator;
  4use std::cell::RefCell;
  5use std::rc::Rc;
  6use std::fmt::Write;
  7use std::str::from_utf8;
  8use std::io::{Error, ErrorKind};
  9use std::collections::HashMap;
 10use std::collections::vec_deque::VecDeque;
 11use tokio_io::codec::{Encoder, Decoder};
 12use minidom::{Element, Node};
 13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 14use xml5ever::interface::Attribute;
 15use bytes::{BytesMut, BufMut};
 16
 17// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
 18
 19#[derive(Debug)]
 20pub enum Packet {
 21    Error(Box<std::error::Error>),
 22    StreamStart(HashMap<String, String>),
 23    Stanza(Element),
 24    Text(String),
 25    StreamEnd,
 26}
 27
 28struct ParserSink {
 29    // Ready stanzas, shared with XMPPCodec
 30    queue: Rc<RefCell<VecDeque<Packet>>>,
 31    // Parsing stack
 32    stack: Vec<Element>,
 33    ns_stack: Vec<HashMap<Option<String>, String>>,
 34}
 35
 36impl ParserSink {
 37    pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
 38        ParserSink {
 39            queue,
 40            stack: vec![],
 41            ns_stack: vec![],
 42        }
 43    }
 44
 45    fn push_queue(&self, pkt: Packet) {
 46        self.queue.borrow_mut().push_back(pkt);
 47    }
 48
 49    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 50        for nss in self.ns_stack.iter().rev() {
 51            if let Some(ns) = nss.get(prefix) {
 52                return Some(ns);
 53            }
 54        }
 55
 56        None
 57    }
 58
 59    fn handle_start_tag(&mut self, tag: Tag) {
 60        let mut nss = HashMap::new();
 61        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
 62            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 63            .unwrap_or(false);
 64        for attr in &tag.attrs {
 65            match attr.name.local.as_ref() {
 66                "xmlns" => {
 67                    nss.insert(None, attr.value.as_ref().to_owned());
 68                },
 69                prefix if is_prefix_xmlns(attr) => {
 70                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 71                    },
 72                _ => (),
 73            }
 74        }
 75        self.ns_stack.push(nss);
 76
 77        let el = {
 78            let mut el_builder = Element::builder(tag.name.local.as_ref());
 79            if let Some(el_ns) = self.lookup_ns(
 80                &tag.name.prefix.map(|prefix|
 81                                     prefix.as_ref().to_owned())
 82            ) {
 83                el_builder = el_builder.ns(el_ns);
 84            }
 85            for attr in &tag.attrs {
 86                match attr.name.local.as_ref() {
 87                    "xmlns" => (),
 88                    _ if is_prefix_xmlns(attr) => (),
 89                    _ => {
 90                        el_builder = el_builder.attr(
 91                            attr.name.local.as_ref(),
 92                            attr.value.as_ref()
 93                        );
 94                    },
 95                }
 96            }
 97            el_builder.build()
 98        };
 99
100        if self.stack.is_empty() {
101            let attrs = HashMap::from_iter(
102                tag.attrs.iter()
103                    .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
104            );
105            self.push_queue(Packet::StreamStart(attrs));
106        }
107
108        self.stack.push(el);
109    }
110
111    fn handle_end_tag(&mut self) {
112        let el = self.stack.pop().unwrap();
113        self.ns_stack.pop();
114
115        match self.stack.len() {
116            // </stream:stream>
117            0 =>
118                self.push_queue(Packet::StreamEnd),
119            // </stanza>
120            1 =>
121                self.push_queue(Packet::Stanza(el)),
122            len => {
123                let parent = &mut self.stack[len - 1];
124                parent.append_child(el);
125            },
126        }
127    }
128}
129
130impl TokenSink for ParserSink {
131    fn process_token(&mut self, token: Token) {
132        match token {
133            Token::TagToken(tag) => match tag.kind {
134                TagKind::StartTag =>
135                    self.handle_start_tag(tag),
136                TagKind::EndTag =>
137                    self.handle_end_tag(),
138                TagKind::EmptyTag => {
139                    self.handle_start_tag(tag);
140                    self.handle_end_tag();
141                },
142                TagKind::ShortTag =>
143                    self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
144            },
145            Token::CharacterTokens(tendril) =>
146                match self.stack.len() {
147                    0 | 1 =>
148                        self.push_queue(Packet::Text(tendril.into())),
149                    len => {
150                        let el = &mut self.stack[len - 1];
151                        el.append_text_node(tendril);
152                    },
153                },
154            Token::EOFToken =>
155                self.push_queue(Packet::StreamEnd),
156            Token::ParseError(s) => {
157                println!("ParseError: {:?}", s);
158                self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
159            },
160            _ => (),
161        }
162    }
163
164    // fn end(&mut self) {
165    // }
166}
167
168pub struct XMPPCodec {
169    /// Outgoing
170    ns: Option<String>,
171    /// Incoming
172    parser: XmlTokenizer<ParserSink>,
173    /// For handling incoming truncated utf8
174    // TODO: optimize using  tendrils?
175    buf: Vec<u8>,
176    /// Shared with ParserSink
177    queue: Rc<RefCell<VecDeque<Packet>>>,
178}
179
180impl XMPPCodec {
181    pub fn new() -> Self {
182        let queue = Rc::new(RefCell::new(VecDeque::new()));
183        let sink = ParserSink::new(queue.clone());
184        // TODO: configure parser?
185        let parser = XmlTokenizer::new(sink, Default::default());
186        XMPPCodec {
187            ns: None,
188            parser,
189            queue,
190            buf: vec![],
191        }
192    }
193}
194
195impl Default for XMPPCodec {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201impl Decoder for XMPPCodec {
202    type Item = Packet;
203    type Error = Error;
204
205    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
206        let buf1: Box<AsRef<[u8]>> =
207            if ! self.buf.is_empty() && ! buf.is_empty() {
208                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
209                prefix.extend_from_slice(buf.take().as_ref());
210                Box::new(prefix)
211            } else {
212                Box::new(buf.take())
213            };
214        let buf1 = buf1.as_ref().as_ref();
215        match from_utf8(buf1) {
216            Ok(s) => {
217                if ! s.is_empty() {
218                    println!("<< {}", s);
219                    let tendril = FromIterator::from_iter(s.chars());
220                    self.parser.feed(tendril);
221                }
222            },
223            // Remedies for truncated utf8
224            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
225                // Prepare all the valid data
226                let mut b = BytesMut::with_capacity(e.valid_up_to());
227                b.put(&buf1[0..e.valid_up_to()]);
228
229                // Retry
230                let result = self.decode(&mut b);
231
232                // Keep the tail back in
233                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
234
235                return result;
236            },
237            Err(e) => {
238                println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
239                return Err(Error::new(ErrorKind::InvalidInput, e));
240            },
241        }
242
243        let result = self.queue.borrow_mut().pop_front();
244        Ok(result)
245    }
246
247    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
248        self.decode(buf)
249    }
250}
251
252impl Encoder for XMPPCodec {
253    type Item = Packet;
254    type Error = Error;
255
256    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
257        let remaining = dst.capacity() - dst.len();
258        let max_stanza_size: usize = 2usize.pow(16);
259        if remaining < max_stanza_size {
260            dst.reserve(max_stanza_size - remaining);
261        }
262
263        match item {
264            Packet::StreamStart(start_attrs) => {
265                let mut buf = String::new();
266                write!(buf, "<stream:stream").unwrap();
267                for (name, value) in start_attrs {
268                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
269                        .unwrap();
270                    if name == "xmlns" {
271                        self.ns = Some(value);
272                    }
273                }
274                write!(buf, ">\n").unwrap();
275
276                print!(">> {}", buf);
277                write!(dst, "{}", buf)
278                    .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
279            },
280            Packet::Stanza(stanza) => {
281                let root_ns = self.ns.as_ref().map(|s| s.as_ref());
282                write_element(&stanza, dst, root_ns)
283                    .and_then(|_| {
284                        println!(">> {:?}", dst);
285                        Ok(())
286                    })
287                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
288            },
289            Packet::Text(text) => {
290                write_text(&text, dst)
291                    .and_then(|_| {
292                        println!(">> {:?}", dst);
293                        Ok(())
294                    })
295                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
296            },
297            // TODO: Implement all
298            _ => Ok(())
299        }
300    }
301}
302
/// Writes `text` to `writer` unchanged; no XML escaping is applied.
///
/// # Errors
/// Propagates any error returned by `writer`.
pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
    writer.write_str(text)
}
306
307// TODO: escape everything?
308pub fn write_element<W: Write>(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> {
309    write!(writer, "<")?;
310    write!(writer, "{}", el.name())?;
311
312    if let Some(ns) = el.ns() {
313        if parent_ns.map(|s| s.as_ref()) != el.ns() {
314            write!(writer, " xmlns=\"{}\"", ns)?;
315        }
316    }
317
318    for (key, value) in el.attrs() {
319        write!(writer, " {}=\"{}\"", key, value)?;
320    }
321
322    if ! el.nodes().any(|_| true) {
323        write!(writer, " />")?;
324        return Ok(())
325    }
326
327    write!(writer, ">")?;
328
329    for node in el.nodes() {
330        match *node {
331            Node::Element(ref child) =>
332                write_element(child, writer, el.ns())?,
333            Node::Text(ref text) =>
334                write_text(text, writer)?,
335        }
336    }
337
338    write!(writer, "</{}>", el.name())?;
339    Ok(())
340}
341
/// Returns `input` with the five XML special characters replaced by their
/// predefined entities: `&` `<` `>` `'` `"`. All other characters are copied
/// unchanged.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    fn entity(ch: char) -> Option<&'static str> {
        match ch {
            '&' => Some("&amp;"),
            '<' => Some("&lt;"),
            '>' => Some("&gt;"),
            '\'' => Some("&apos;"),
            '"' => Some("&quot;"),
            _ => None,
        }
    }

    input.chars().fold(String::with_capacity(input.len()), |mut out, ch| {
        match entity(ch) {
            Some(replacement) => out.push_str(replacement),
            None => out.push(ch),
        }
        out
    })
}
358
359#[cfg(test)]
360mod tests {
361    use super::*;
362    use bytes::BytesMut;
363
364    #[test]
365    fn test_stream_start() {
366        let mut c = XMPPCodec::new();
367        let mut b = BytesMut::with_capacity(1024);
368        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
369        let r = c.decode(&mut b);
370        assert!(match r {
371            Ok(Some(Packet::StreamStart(_))) => true,
372            _ => false,
373        });
374    }
375
376    #[test]
377    fn test_truncated_stanza() {
378        let mut c = XMPPCodec::new();
379        let mut b = BytesMut::with_capacity(1024);
380        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
381        let r = c.decode(&mut b);
382        assert!(match r {
383            Ok(Some(Packet::StreamStart(_))) => true,
384            _ => false,
385        });
386
387        b.clear();
388        b.put(r"<test>ß</test");
389        let r = c.decode(&mut b);
390        assert!(match r {
391            Ok(None) => true,
392            _ => false,
393        });
394
395        b.clear();
396        b.put(r">");
397        let r = c.decode(&mut b);
398        assert!(match r {
399            Ok(Some(Packet::Stanza(ref el)))
400                if el.name() == "test"
401                && el.text() == "ß"
402                => true,
403            _ => false,
404        });
405    }
406
407    #[test]
408    fn test_truncated_utf8() {
409        let mut c = XMPPCodec::new();
410        let mut b = BytesMut::with_capacity(1024);
411        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
412        let r = c.decode(&mut b);
413        assert!(match r {
414            Ok(Some(Packet::StreamStart(_))) => true,
415            _ => false,
416        });
417
418        b.clear();
419        b.put(&b"<test>\xc3"[..]);
420        let r = c.decode(&mut b);
421        assert!(match r {
422            Ok(None) => true,
423            _ => false,
424        });
425
426        b.clear();
427        b.put(&b"\x9f</test>"[..]);
428        let r = c.decode(&mut b);
429        assert!(match r {
430            Ok(Some(Packet::Stanza(ref el)))
431                if el.name() == "test"
432                && el.text() == "ß"
433                => true,
434            _ => false,
435        });
436    }
437
438    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
439    #[test]
440    fn test_large_stanza() {
441        use std::io::Cursor;
442        use futures::{Future, Sink};
443        use tokio_io::codec::FramedWrite;
444        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
445        let mut text = "".to_owned();
446        for _ in 0..2usize.pow(15) {
447            text = text + "A";
448        }
449        let stanza = Element::builder("message")
450            .append(
451                Element::builder("body")
452                    .append(&text)
453                    .build()
454            )
455            .build();
456        let framed = framed.send(Packet::Stanza(stanza))
457            .wait()
458            .expect("send");
459        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
460    }
461}