xmpp_codec: add remedies for truncated utf8

Astro created

Change summary

src/xmpp_codec.rs | 113 +++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 110 insertions(+), 3 deletions(-)

Detailed changes

src/xmpp_codec.rs 🔗

@@ -56,6 +56,7 @@ pub enum Packet {
 pub struct XMPPCodec {
     parser: xml::Parser,
     root: Option<XMPPRoot>,
+    buf: Vec<u8>, // bytes held back between decode() calls: the tail after the last valid UTF-8 boundary, prepended to the next input
 }
 
 impl XMPPCodec {
@@ -63,6 +64,7 @@ impl XMPPCodec {
         XMPPCodec {
             parser: xml::Parser::new(),
             root: None,
+            buf: vec![],
         }
     }
 }
@@ -72,15 +74,40 @@ impl Decoder for XMPPCodec {
     type Error = Error;
 
     fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
-        match from_utf8(buf.take().as_ref()) {
+        let buf1: Box<AsRef<[u8]>> =
+            if self.buf.len() > 0 && buf.len() > 0 {
+                let mut prefix = std::mem::replace(&mut self.buf, vec![]);
+                prefix.extend_from_slice(buf.take().as_ref());
+                Box::new(prefix)
+            } else {
+                Box::new(buf.take())
+            };
+        let buf1 = buf1.as_ref().as_ref();
+        match from_utf8(buf1) {
             Ok(s) => {
                 if s.len() > 0 {
                     println!("<< {}", s);
                     self.parser.feed_str(s);
                 }
             },
-            Err(e) =>
-                return Err(Error::new(ErrorKind::InvalidInput, e)),
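+            // Remedies for truncated UTF-8. NOTE(review): `buf1.len() - 3` underflows when fewer than 3 bytes are buffered; `e.valid_up_to() + 3 >= buf1.len()` would avoid that.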
+            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
+                // Prepare all the valid data
+                let mut b = BytesMut::with_capacity(e.valid_up_to());
+                b.put(&buf1[0..e.valid_up_to()]);
+
+                // Retry
+                let result = self.decode(&mut b);
+
+                // Stash the undecodable tail so it is prepended to the next chunk
+                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
+
+                return result;
+            },
+            Err(e) => {
+                println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
+                return Err(Error::new(ErrorKind::InvalidInput, e));
+            },
         }
 
         let mut new_root: Option<XMPPRoot> = None;
@@ -171,3 +198,83 @@ impl Encoder for XMPPCodec {
         .map_err(|_| Error::from(ErrorKind::InvalidInput))
     }
 }
+
+#[cfg(test)]
+mod tests { // decoder tests: stream start, plus input that is split across decode() calls
+    use super::*;
+    use bytes::BytesMut;
+
+    #[test]
+    fn test_stream_start() { // a complete stream header must decode to Packet::StreamStart
+        let mut c = XMPPCodec::new();
+        let mut b = BytesMut::with_capacity(1024);
+        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(Some(Packet::StreamStart(_))) => true,
+            _ => false,
+        });
+    }
+
+    #[test]
+    fn test_truncated_stanza() { // a stanza whose closing tag is split across two decode() calls
+        let mut c = XMPPCodec::new();
+        let mut b = BytesMut::with_capacity(1024);
+        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(Some(Packet::StreamStart(_))) => true,
+            _ => false,
+        });
+
+        b.clear();
+        b.put(r"<test>ß</test"); // closing tag is missing its final '>'
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(None) => true, // no packet is expected until the tag is completed
+            _ => false,
+        });
+
+        b.clear();
+        b.put(r">"); // the remainder of the closing tag
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(Some(Packet::Stanza(ref el)))
+                if el.name == "test"
+                && el.content_str() == "ß"
+                => true,
+            _ => false,
+        });
+    }
+
+    #[test]
+    fn test_truncated_utf8() { // a multi-byte character split across two decode() calls
+        let mut c = XMPPCodec::new();
+        let mut b = BytesMut::with_capacity(1024);
+        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(Some(Packet::StreamStart(_))) => true,
+            _ => false,
+        });
+
+        b.clear();
+        b.put(&b"<test>\xc3"[..]); // 0xC3 is the first byte of "ß" (UTF-8: C3 9F)
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(None) => true, // the dangling lead byte must yield neither a packet nor an error
+            _ => false,
+        });
+
+        b.clear();
+        b.put(&b"\x9f</test>"[..]); // 0x9F completes the "ß" started in the previous chunk
+        let r = c.decode(&mut b);
+        assert!(match r {
+            Ok(Some(Packet::Stanza(ref el)))
+                if el.name == "test"
+                && el.content_str() == "ß"
+                => true,
+            _ => false,
+        });
+    }
+}