xmpp_codec.rs

  1use std;
  2use std::fmt::Write;
  3use std::str::from_utf8;
  4use std::io::{Error, ErrorKind};
  5use std::collections::HashMap;
  6use tokio_io::codec::{Encoder, Decoder};
  7use xml;
  8use bytes::*;
  9
 10const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
 11
 12pub type Attributes = HashMap<(String, Option<String>), String>;
 13
 14struct XMPPRoot {
 15    builder: xml::ElementBuilder,
 16    pub attributes: Attributes,
 17}
 18
 19impl XMPPRoot {
 20    fn new(root: xml::StartTag) -> Self {
 21        let mut builder = xml::ElementBuilder::new();
 22        let mut attributes = HashMap::new();
 23        for (name_ns, value) in root.attributes {
 24            match name_ns {
 25                (ref name, None) if name == "xmlns" =>
 26                    builder.set_default_ns(value),
 27                (ref prefix, Some(ref ns)) if ns == NS_XMLNS =>
 28                    builder.define_prefix(prefix.to_owned(), value),
 29                _ => {
 30                    attributes.insert(name_ns, value);
 31                },
 32            }
 33        }
 34
 35        XMPPRoot {
 36            builder: builder,
 37            attributes: attributes,
 38        }
 39    }
 40
 41    fn handle_event(&mut self, event: Result<xml::Event, xml::ParserError>)
 42                    -> Option<Result<xml::Element, xml::BuilderError>> {
 43        self.builder.handle_event(event)
 44    }
 45}
 46
 47#[derive(Debug)]
 48pub enum Packet {
 49    Error(Box<std::error::Error>),
 50    StreamStart(HashMap<String, String>),
 51    Stanza(xml::Element),
 52    Text(String),
 53    StreamEnd,
 54}
 55
 56pub struct XMPPCodec {
 57    parser: xml::Parser,
 58    root: Option<XMPPRoot>,
 59    buf: Vec<u8>,
 60}
 61
 62impl XMPPCodec {
 63    pub fn new() -> Self {
 64        XMPPCodec {
 65            parser: xml::Parser::new(),
 66            root: None,
 67            buf: vec![],
 68        }
 69    }
 70}
 71
 72impl Decoder for XMPPCodec {
 73    type Item = Packet;
 74    type Error = Error;
 75
 76    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
 77        let buf1: Box<AsRef<[u8]>> =
 78            if self.buf.len() > 0 && buf.len() > 0 {
 79                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
 80                prefix.extend_from_slice(buf.take().as_ref());
 81                Box::new(prefix)
 82            } else {
 83                Box::new(buf.take())
 84            };
 85        let buf1 = buf1.as_ref().as_ref();
 86        match from_utf8(buf1) {
 87            Ok(s) => {
 88                if s.len() > 0 {
 89                    println!("<< {}", s);
 90                    self.parser.feed_str(s);
 91                }
 92            },
 93            // Remedies for truncated utf8
 94            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
 95                // Prepare all the valid data
 96                let mut b = BytesMut::with_capacity(e.valid_up_to());
 97                b.put(&buf1[0..e.valid_up_to()]);
 98
 99                // Retry
100                let result = self.decode(&mut b);
101
102                // Keep the tail back in
103                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
104
105                return result;
106            },
107            Err(e) => {
108                println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
109                return Err(Error::new(ErrorKind::InvalidInput, e));
110            },
111        }
112
113        let mut new_root: Option<XMPPRoot> = None;
114        let mut result = None;
115        for event in &mut self.parser {
116            match self.root {
117                None => {
118                    // Expecting <stream:stream>
119                    match event {
120                        Ok(xml::Event::ElementStart(start_tag)) => {
121                            let mut attrs: HashMap<String, String> = HashMap::new();
122                            for (&(ref name, _), value) in &start_tag.attributes {
123                                attrs.insert(name.to_owned(), value.to_owned());
124                            }
125                            result = Some(Packet::StreamStart(attrs));
126                            self.root = Some(XMPPRoot::new(start_tag));
127                            break
128                        },
129                        Err(e) => {
130                            result = Some(Packet::Error(Box::new(e)));
131                            break
132                        },
133                        _ =>
134                            (),
135                    }
136                }
137
138                Some(ref mut root) => {
139                    match root.handle_event(event) {
140                        None => (),
141                        Some(Ok(stanza)) => {
142                            // Emit the stanza
143                            result = Some(Packet::Stanza(stanza));
144                            break
145                        },
146                        Some(Err(e)) => {
147                            result = Some(Packet::Error(Box::new(e)));
148                            break
149                        }
150                    };
151                },
152            }
153
154            match new_root.take() {
155                None => (),
156                Some(root) => self.root = Some(root),
157            }
158        }
159
160        Ok(result)
161    }
162
163    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
164        self.decode(buf)
165    }
166}
167
168impl Encoder for XMPPCodec {
169    type Item = Packet;
170    type Error = Error;
171
172    fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
173        match item {
174            Packet::StreamStart(start_attrs) => {
175                let mut buf = String::new();
176                write!(buf, "<stream:stream").unwrap();
177                for (ref name, ref value) in &start_attrs {
178                    write!(buf, " {}=\"{}\"", xml::escape(&name), xml::escape(&value))
179                        .unwrap();
180                }
181                write!(buf, ">\n").unwrap();
182
183                print!(">> {}", buf);
184                write!(dst, "{}", buf)
185            },
186            Packet::Stanza(stanza) => {
187                println!(">> {}", stanza);
188                write!(dst, "{}", stanza)
189            },
190            Packet::Text(text) => {
191                let escaped = xml::escape(&text);
192                println!(">> {}", escaped);
193                write!(dst, "{}", escaped)
194            },
195            // TODO: Implement all
196            _ => Ok(())
197        }
198        .map_err(|_| Error::from(ErrorKind::InvalidInput))
199    }
200}
201
202#[cfg(test)]
203mod tests {
204    use super::*;
205    use bytes::BytesMut;
206
207    #[test]
208    fn test_stream_start() {
209        let mut c = XMPPCodec::new();
210        let mut b = BytesMut::with_capacity(1024);
211        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
212        let r = c.decode(&mut b);
213        assert!(match r {
214            Ok(Some(Packet::StreamStart(_))) => true,
215            _ => false,
216        });
217    }
218
219    #[test]
220    fn test_truncated_stanza() {
221        let mut c = XMPPCodec::new();
222        let mut b = BytesMut::with_capacity(1024);
223        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
224        let r = c.decode(&mut b);
225        assert!(match r {
226            Ok(Some(Packet::StreamStart(_))) => true,
227            _ => false,
228        });
229
230        b.clear();
231        b.put(r"<test>ß</test");
232        let r = c.decode(&mut b);
233        assert!(match r {
234            Ok(None) => true,
235            _ => false,
236        });
237
238        b.clear();
239        b.put(r">");
240        let r = c.decode(&mut b);
241        assert!(match r {
242            Ok(Some(Packet::Stanza(ref el)))
243                if el.name == "test"
244                && el.content_str() == "ß"
245                => true,
246            _ => false,
247        });
248    }
249
250    #[test]
251    fn test_truncated_utf8() {
252        let mut c = XMPPCodec::new();
253        let mut b = BytesMut::with_capacity(1024);
254        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
255        let r = c.decode(&mut b);
256        assert!(match r {
257            Ok(Some(Packet::StreamStart(_))) => true,
258            _ => false,
259        });
260
261        b.clear();
262        b.put(&b"<test>\xc3"[..]);
263        let r = c.decode(&mut b);
264        assert!(match r {
265            Ok(None) => true,
266            _ => false,
267        });
268
269        b.clear();
270        b.put(&b"\x9f</test>"[..]);
271        let r = c.decode(&mut b);
272        assert!(match r {
273            Ok(Some(Packet::Stanza(ref el)))
274                if el.name == "test"
275                && el.content_str() == "ß"
276                => true,
277            _ => false,
278        });
279    }
280}