xmpp_codec.rs

  1use std;
  2use std::default::Default;
  3use std::iter::FromIterator;
  4use std::cell::RefCell;
  5use std::rc::Rc;
  6use std::fmt::Write;
  7use std::str::from_utf8;
  8use std::io::{Error, ErrorKind};
  9use std::collections::HashMap;
 10use std::collections::vec_deque::VecDeque;
 11use tokio_io::codec::{Encoder, Decoder};
 12use minidom::{Element, Node};
 13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
 14use xml5ever::interface::Attribute;
 15use bytes::*;
 16
 17// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
 18
 19#[derive(Debug)]
 20pub enum Packet {
 21    Error(Box<std::error::Error>),
 22    StreamStart(HashMap<String, String>),
 23    Stanza(Element),
 24    Text(String),
 25    StreamEnd,
 26}
 27
 28struct ParserSink {
 29    // Ready stanzas, shared with XMPPCodec
 30    queue: Rc<RefCell<VecDeque<Packet>>>,
 31    // Parsing stack
 32    stack: Vec<Element>,
 33    ns_stack: Vec<HashMap<Option<String>, String>>,
 34}
 35
 36impl ParserSink {
 37    pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
 38        ParserSink {
 39            queue,
 40            stack: vec![],
 41            ns_stack: vec![],
 42        }
 43    }
 44
 45    fn push_queue(&self, pkt: Packet) {
 46        self.queue.borrow_mut().push_back(pkt);
 47    }
 48
 49    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 50        for nss in self.ns_stack.iter().rev() {
 51            match nss.get(prefix) {
 52                Some(ns) => return Some(ns),
 53                None => (),
 54            }
 55        }
 56
 57        None
 58    }
 59
 60    fn handle_start_tag(&mut self, tag: Tag) {
 61        let mut nss = HashMap::new();
 62        let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
 63            .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 64            .unwrap_or(false);
 65        for attr in &tag.attrs {
 66            match attr.name.local.as_ref() {
 67                "xmlns" => {
 68                    nss.insert(None, attr.value.as_ref().to_owned());
 69                },
 70                prefix if is_prefix_xmlns(attr) => {
 71                        nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 72                    },
 73                _ => (),
 74            }
 75        }
 76        self.ns_stack.push(nss);
 77
 78        let el = {
 79            let mut el_builder = Element::builder(tag.name.local.as_ref());
 80            match self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned())) {
 81                Some(el_ns) => el_builder = el_builder.ns(el_ns),
 82                None => (),
 83            }
 84            for attr in &tag.attrs {
 85                match attr.name.local.as_ref() {
 86                    "xmlns" => (),
 87                    _ if is_prefix_xmlns(attr) => (),
 88                    _ => {
 89                        el_builder = el_builder.attr(
 90                            attr.name.local.as_ref(),
 91                            attr.value.as_ref()
 92                        );
 93                    },
 94                }
 95            }
 96            el_builder.build()
 97        };
 98
 99        if self.stack.is_empty() {
100            let attrs = HashMap::from_iter(
101                el.attrs()
102                    .map(|(name, value)| (name.to_owned(), value.to_owned()))
103            );
104            self.push_queue(Packet::StreamStart(attrs));
105        }
106
107        self.stack.push(el);
108    }
109
110    fn handle_end_tag(&mut self) {
111        let el = self.stack.pop().unwrap();
112        self.ns_stack.pop();
113
114        match self.stack.len() {
115            // </stream:stream>
116            0 =>
117                self.push_queue(Packet::StreamEnd),
118            // </stanza>
119            1 =>
120                self.push_queue(Packet::Stanza(el)),
121            len => {
122                let parent = &mut self.stack[len - 1];
123                parent.append_child(el);
124            },
125        }
126    }
127}
128
129impl TokenSink for ParserSink {
130    fn process_token(&mut self, token: Token) {
131        match token {
132            Token::TagToken(tag) => match tag.kind {
133                TagKind::StartTag =>
134                    self.handle_start_tag(tag),
135                TagKind::EndTag =>
136                    self.handle_end_tag(),
137                TagKind::EmptyTag => {
138                    self.handle_start_tag(tag);
139                    self.handle_end_tag();
140                },
141                TagKind::ShortTag =>
142                    self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
143            },
144            Token::CharacterTokens(tendril) =>
145                match self.stack.len() {
146                    0 | 1 =>
147                        self.push_queue(Packet::Text(tendril.into())),
148                    len => {
149                        let el = &mut self.stack[len - 1];
150                        el.append_text_node(tendril);
151                    },
152                },
153            Token::EOFToken =>
154                self.push_queue(Packet::StreamEnd),
155            Token::ParseError(s) => {
156                println!("ParseError: {:?}", s);
157                self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
158            },
159            _ => (),
160        }
161    }
162
163    // fn end(&mut self) {
164    // }
165}
166
167pub struct XMPPCodec {
168    /// Outgoing
169    ns: Option<String>,
170    /// Incoming
171    parser: XmlTokenizer<ParserSink>,
172    /// For handling incoming truncated utf8
173    // TODO: optimize using  tendrils?
174    buf: Vec<u8>,
175    /// Shared with ParserSink
176    queue: Rc<RefCell<VecDeque<Packet>>>,
177}
178
179impl XMPPCodec {
180    pub fn new() -> Self {
181        let queue = Rc::new(RefCell::new((VecDeque::new())));
182        let sink = ParserSink::new(queue.clone());
183        // TODO: configure parser?
184        let parser = XmlTokenizer::new(sink, Default::default());
185        XMPPCodec {
186            ns: None,
187            parser,
188            queue,
189            buf: vec![],
190        }
191    }
192}
193
194impl Decoder for XMPPCodec {
195    type Item = Packet;
196    type Error = Error;
197
198    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
199        let buf1: Box<AsRef<[u8]>> =
200            if self.buf.len() > 0 && buf.len() > 0 {
201                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
202                prefix.extend_from_slice(buf.take().as_ref());
203                Box::new(prefix)
204            } else {
205                Box::new(buf.take())
206            };
207        let buf1 = buf1.as_ref().as_ref();
208        match from_utf8(buf1) {
209            Ok(s) => {
210                if s.len() > 0 {
211                    println!("<< {}", s);
212                    let tendril = FromIterator::from_iter(s.chars());
213                    self.parser.feed(tendril);
214                }
215            },
216            // Remedies for truncated utf8
217            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
218                // Prepare all the valid data
219                let mut b = BytesMut::with_capacity(e.valid_up_to());
220                b.put(&buf1[0..e.valid_up_to()]);
221
222                // Retry
223                let result = self.decode(&mut b);
224
225                // Keep the tail back in
226                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
227
228                return result;
229            },
230            Err(e) => {
231                println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
232                return Err(Error::new(ErrorKind::InvalidInput, e));
233            },
234        }
235
236        let result = self.queue.borrow_mut().pop_front();
237        Ok(result)
238    }
239
240    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
241        self.decode(buf)
242    }
243}
244
245impl Encoder for XMPPCodec {
246    type Item = Packet;
247    type Error = Error;
248
249    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
250        match item {
251            Packet::StreamStart(start_attrs) => {
252                let mut buf = String::new();
253                write!(buf, "<stream:stream").unwrap();
254                for (name, value) in start_attrs.into_iter() {
255                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
256                        .unwrap();
257                    if name == "xmlns" {
258                        self.ns = Some(value);
259                    }
260                }
261                write!(buf, ">\n").unwrap();
262
263                print!(">> {}", buf);
264                write!(dst, "{}", buf)
265                    .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
266            },
267            Packet::Stanza(stanza) => {
268                let root_ns = self.ns.as_ref().map(|s| s.as_ref());
269                write_element(&stanza, dst, root_ns)
270                    .and_then(|_| {
271                        println!(">> {:?}", dst);
272                        Ok(())
273                    })
274                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
275            },
276            Packet::Text(text) => {
277                write_text(&text, dst)
278                    .and_then(|_| {
279                        println!(">> {:?}", dst);
280                        Ok(())
281                    })
282                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
283            },
284            // TODO: Implement all
285            _ => Ok(())
286        }
287    }
288}
289
/// Writes `text` to `writer` as XML character data.
///
/// The five XML special characters (`&`, `<`, `>`, `'`, `"`) are replaced
/// with their predefined entities, so arbitrary text cannot inject markup.
/// Runs of ordinary characters are forwarded in a single `write_str` call.
///
/// # Errors
/// Propagates any error returned by `writer`.
pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
    let mut start = 0;
    for (i, c) in text.char_indices() {
        let entity = match c {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            _ => continue,
        };
        writer.write_str(&text[start..i])?;
        writer.write_str(entity)?;
        start = i + c.len_utf8();
    }
    writer.write_str(&text[start..])
}
293
294// TODO: escape everything?
295pub fn write_element<W: Write>(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> {
296    write!(writer, "<")?;
297    write!(writer, "{}", el.name())?;
298
299    if let Some(ref ns) = el.ns() {
300        if parent_ns.map(|s| s.as_ref()) != el.ns() {
301            write!(writer, " xmlns=\"{}\"", ns)?;
302        }
303    }
304
305    for (key, value) in el.attrs() {
306        write!(writer, " {}=\"{}\"", key, value)?;
307    }
308
309    if ! el.nodes().any(|_| true) {
310        write!(writer, " />")?;
311        return Ok(())
312    }
313
314    write!(writer, ">")?;
315
316    for node in el.nodes() {
317        match node {
318            &Node::Element(ref child) =>
319                write_element(child, writer, el.ns())?,
320            &Node::Text(ref text) =>
321                write_text(text, writer)?,
322        }
323    }
324
325    write!(writer, "</{}>", el.name())?;
326    Ok(())
327}
328
/// Returns a copy of `input` with the five XML special characters replaced
/// by their predefined entities: `&amp;`, `&lt;`, `&gt;`, `&apos;` and
/// `&quot;`. All other characters are copied unchanged.
///
/// Copied from RustyXML for now.
pub fn escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut out, ch| {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '\'' => out.push_str("&apos;"),
                '"' => out.push_str("&quot;"),
                other => out.push(other),
            }
            out
        })
}
345
346#[cfg(test)]
347mod tests {
348    use super::*;
349    use bytes::BytesMut;
350
351    #[test]
352    fn test_stream_start() {
353        let mut c = XMPPCodec::new();
354        let mut b = BytesMut::with_capacity(1024);
355        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
356        let r = c.decode(&mut b);
357        assert!(match r {
358            Ok(Some(Packet::StreamStart(_))) => true,
359            _ => false,
360        });
361    }
362
363    #[test]
364    fn test_truncated_stanza() {
365        let mut c = XMPPCodec::new();
366        let mut b = BytesMut::with_capacity(1024);
367        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
368        let r = c.decode(&mut b);
369        assert!(match r {
370            Ok(Some(Packet::StreamStart(_))) => true,
371            _ => false,
372        });
373
374        b.clear();
375        b.put(r"<test>ß</test");
376        let r = c.decode(&mut b);
377        assert!(match r {
378            Ok(None) => true,
379            _ => false,
380        });
381
382        b.clear();
383        b.put(r">");
384        let r = c.decode(&mut b);
385        assert!(match r {
386            Ok(Some(Packet::Stanza(ref el)))
387                if el.name() == "test"
388                && el.text() == "ß"
389                => true,
390            _ => false,
391        });
392    }
393
394    #[test]
395    fn test_truncated_utf8() {
396        let mut c = XMPPCodec::new();
397        let mut b = BytesMut::with_capacity(1024);
398        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
399        let r = c.decode(&mut b);
400        assert!(match r {
401            Ok(Some(Packet::StreamStart(_))) => true,
402            _ => false,
403        });
404
405        b.clear();
406        b.put(&b"<test>\xc3"[..]);
407        let r = c.decode(&mut b);
408        assert!(match r {
409            Ok(None) => true,
410            _ => false,
411        });
412
413        b.clear();
414        b.put(&b"\x9f</test>"[..]);
415        let r = c.decode(&mut b);
416        assert!(match r {
417            Ok(Some(Packet::Stanza(ref el)))
418                if el.name() == "test"
419                && el.text() == "ß"
420                => true,
421            _ => false,
422        });
423    }
424}