@@ -56,6 +56,7 @@ pub enum Packet {
pub struct XMPPCodec {
parser: xml::Parser,
root: Option<XMPPRoot>,
+ buf: Vec<u8>,
}
impl XMPPCodec {
@@ -63,6 +64,7 @@ impl XMPPCodec {
XMPPCodec {
parser: xml::Parser::new(),
root: None,
+ buf: vec![],
}
}
}
@@ -72,15 +74,40 @@ impl Decoder for XMPPCodec {
type Error = Error;
fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
- match from_utf8(buf.take().as_ref()) {
+ let buf1: Box<AsRef<[u8]>> =
+ if self.buf.len() > 0 && buf.len() > 0 {
+ let mut prefix = std::mem::replace(&mut self.buf, vec![]);
+ prefix.extend_from_slice(buf.take().as_ref());
+ Box::new(prefix)
+ } else {
+ Box::new(buf.take())
+ };
+ let buf1 = buf1.as_ref().as_ref();
+ match from_utf8(buf1) {
Ok(s) => {
if s.len() > 0 {
println!("<< {}", s);
self.parser.feed_str(s);
}
},
- Err(e) =>
- return Err(Error::new(ErrorKind::InvalidInput, e)),
+ // Remedies for truncated utf8
+ Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
+ // Prepare all the valid data
+ let mut b = BytesMut::with_capacity(e.valid_up_to());
+ b.put(&buf1[0..e.valid_up_to()]);
+
+ // Retry
+ let result = self.decode(&mut b);
+
+ // Keep the tail back in
+ self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
+
+ return result;
+ },
+ Err(e) => {
+ println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
+ return Err(Error::new(ErrorKind::InvalidInput, e));
+ },
}
let mut new_root: Option<XMPPRoot> = None;
@@ -171,3 +198,83 @@ impl Encoder for XMPPCodec {
.map_err(|_| Error::from(ErrorKind::InvalidInput))
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use bytes::BytesMut;
+
+ #[test]
+ fn test_stream_start() {
+ let mut c = XMPPCodec::new();
+ let mut b = BytesMut::with_capacity(1024);
+ b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(Some(Packet::StreamStart(_))) => true,
+ _ => false,
+ });
+ }
+
+ #[test]
+ fn test_truncated_stanza() {
+ let mut c = XMPPCodec::new();
+ let mut b = BytesMut::with_capacity(1024);
+ b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(Some(Packet::StreamStart(_))) => true,
+ _ => false,
+ });
+
+ b.clear();
+ b.put(r"<test>ß</test");
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(None) => true,
+ _ => false,
+ });
+
+ b.clear();
+ b.put(r">");
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(Some(Packet::Stanza(ref el)))
+ if el.name == "test"
+ && el.content_str() == "ß"
+ => true,
+ _ => false,
+ });
+ }
+
+ #[test]
+ fn test_truncated_utf8() {
+ let mut c = XMPPCodec::new();
+ let mut b = BytesMut::with_capacity(1024);
+ b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(Some(Packet::StreamStart(_))) => true,
+ _ => false,
+ });
+
+ b.clear();
+ b.put(&b"<test>\xc3"[..]);
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(None) => true,
+ _ => false,
+ });
+
+ b.clear();
+ b.put(&b"\x9f</test>"[..]);
+ let r = c.decode(&mut b);
+ assert!(match r {
+ Ok(Some(Packet::Stanza(ref el)))
+ if el.name == "test"
+ && el.content_str() == "ß"
+ => true,
+ _ => false,
+ });
+ }
+}