a proper ParserError type for XMPPCodec

Astro created

Change summary

src/stream_start.rs | 14 +++---
src/xmpp_codec.rs   | 95 ++++++++++++++++++++++++++++++++++++++--------
2 files changed, 84 insertions(+), 25 deletions(-)

Detailed changes

src/stream_start.rs 🔗

@@ -1,12 +1,12 @@
 use std::mem::replace;
-use std::io::{Error, ErrorKind};
+use std::borrow::Cow;
 use futures::{Future, Async, Poll, Stream, sink, Sink};
 use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_codec::Framed;
 use jid::Jid;
 use minidom::Element;
 
-use xmpp_codec::{XMPPCodec, Packet};
+use xmpp_codec::{XMPPCodec, Packet, ParserError};
 use xmpp_stream::XMPPStream;
 
 const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
@@ -43,7 +43,7 @@ impl<S: AsyncWrite> StreamStart<S> {
 
 impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
     type Item = XMPPStream<S>;
-    type Error = Error;
+    type Error = ParserError;
 
     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
         let old_state = replace(&mut self.state, StreamStartState::Invalid);
@@ -59,7 +59,7 @@ impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
                     Ok(Async::NotReady) =>
                         (StreamStartState::SendStart(send), Ok(Async::NotReady)),
                     Err(e) =>
-                        (StreamStartState::Invalid, Err(e)),
+                        (StreamStartState::Invalid, Err(e.into())),
                 },
             StreamStartState::RecvStart(mut stream) =>
                 match stream.poll() {
@@ -67,7 +67,7 @@ impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
                         let stream_ns = match stream_attrs.get("xmlns") {
                             Some(ns) => ns.clone(),
                             None =>
-                                return Err(Error::from(ErrorKind::InvalidData)),
+                                return Err(ParserError::Parse(Cow::from("Missing stream namespace"))),
                         };
                         if self.ns == "jabber:client" {
                             retry = true;
@@ -77,7 +77,7 @@ impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
                             let id = match stream_attrs.get("id") {
                                 Some(id) => id.clone(),
                                 None =>
-                                    return Err(Error::from(ErrorKind::InvalidData)),
+                                    return Err(ParserError::Parse(Cow::from("No stream id"))),
                             };
                                                                                                     // FIXME: huge hack, shouldn’t be an element!
                             let stream = XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), Element::builder(id).build());
@@ -85,7 +85,7 @@ impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
                         }
                     },
                     Ok(Async::Ready(_)) =>
-                        return Err(Error::from(ErrorKind::InvalidData)),
+                        return Err(ParserError::Parse(Cow::from("Invalid XML event received"))),
                     Ok(Async::NotReady) =>
                         (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
                     Err(e) =>

src/xmpp_codec.rs 🔗

@@ -6,10 +6,13 @@ use std::iter::FromIterator;
 use std::cell::RefCell;
 use std::rc::Rc;
 use std::fmt::Write;
-use std::str::from_utf8;
-use std::io::{Error, ErrorKind};
+use std::str::{from_utf8, Utf8Error};
+use std::io;
 use std::collections::HashMap;
 use std::collections::vec_deque::VecDeque;
+use std::error::Error as StdError;
+use std::fmt;
+use std::borrow::Cow;
 use tokio_codec::{Encoder, Decoder};
 use minidom::Element;
 use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
@@ -20,8 +23,6 @@ use quick_xml::Writer as EventWriter;
 /// Anything that can be sent or received on an XMPP/XML stream
 #[derive(Debug)]
 pub enum Packet {
-    /// General error (`InvalidInput`)
-    Error(Box<std::error::Error>),
     /// `<stream:stream>` start tag
     StreamStart(HashMap<String, String>),
     /// A complete stanza or nonza
@@ -32,17 +33,68 @@ pub enum Packet {
     StreamEnd,
 }
 
+/// Causes for stream parsing errors
+#[derive(Debug)]
+pub enum ParserError {
+    /// Encoding error
+    Utf8(Utf8Error),
+    /// XML parse error
+    Parse(Cow<'static, str>),
+    /// Illegal `</>`
+    ShortTag,
+    /// Required by `impl Decoder`
+    IO(io::Error),
+}
+
+impl From<io::Error> for ParserError {
+    fn from(e: io::Error) -> Self {
+        ParserError::IO(e)
+    }
+}
+
+impl StdError for ParserError {
+    fn description(&self) -> &str {
+        match *self {
+            ParserError::Utf8(ref ue) => ue.description(),
+            ParserError::Parse(ref pe) => pe,
+            ParserError::ShortTag => "short tag",
+            ParserError::IO(ref ie) => ie.description(),
+        }
+    }
+
+    fn cause(&self) -> Option<&StdError> {
+        match *self {
+            ParserError::Utf8(ref ue) => ue.cause(),
+            ParserError::IO(ref ie) => ie.cause(),
+            _ => None,
+        }
+    }
+}
+
+impl fmt::Display for ParserError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match *self {
+            ParserError::Utf8(ref ue) => write!(f, "{}", ue),
+            ParserError::Parse(ref pe) => write!(f, "{}", pe),
+            ParserError::ShortTag => write!(f, "Short tag"),
+            ParserError::IO(ref ie) => write!(f, "{}", ie),
+        }
+    }
+}
+
+type QueueItem = Result<Packet, ParserError>;
+
 /// Parser state
 struct ParserSink {
     // Ready stanzas, shared with XMPPCodec
-    queue: Rc<RefCell<VecDeque<Packet>>>,
+    queue: Rc<RefCell<VecDeque<QueueItem>>>,
     // Parsing stack
     stack: Vec<Element>,
     ns_stack: Vec<HashMap<Option<String>, String>>,
 }
 
 impl ParserSink {
-    pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
+    pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
         ParserSink {
             queue,
             stack: vec![],
@@ -51,7 +103,11 @@ impl ParserSink {
     }
 
     fn push_queue(&self, pkt: Packet) {
-        self.queue.borrow_mut().push_back(pkt);
+        self.queue.borrow_mut().push_back(Ok(pkt));
+    }
+
+    fn push_queue_error(&self, e: ParserError) {
+        self.queue.borrow_mut().push_back(Err(e));
     }
 
     /// Lookup XML namespace declaration for given prefix (or no prefix)
@@ -149,7 +205,7 @@ impl TokenSink for ParserSink {
                     self.handle_end_tag();
                 },
                 TagKind::ShortTag =>
-                    self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
+                    self.push_queue_error(ParserError::ShortTag),
             },
             Token::CharacterTokens(tendril) =>
                 match self.stack.len() {
@@ -164,7 +220,7 @@ impl TokenSink for ParserSink {
                 self.push_queue(Packet::StreamEnd),
             Token::ParseError(s) => {
                 // println!("ParseError: {:?}", s);
-                self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
+                self.push_queue_error(ParserError::Parse(s));
             },
             _ => (),
         }
@@ -184,7 +240,7 @@ pub struct XMPPCodec {
     // TODO: optimize using  tendrils?
     buf: Vec<u8>,
     /// Shared with ParserSink
-    queue: Rc<RefCell<VecDeque<Packet>>>,
+    queue: Rc<RefCell<VecDeque<QueueItem>>>,
 }
 
 impl XMPPCodec {
@@ -211,7 +267,7 @@ impl Default for XMPPCodec {
 
 impl Decoder for XMPPCodec {
     type Item = Packet;
-    type Error = Error;
+    type Error = ParserError;
 
     fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
         let buf1: Box<AsRef<[u8]>> =
@@ -247,12 +303,15 @@ impl Decoder for XMPPCodec {
             },
             Err(e) => {
                 // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
-                return Err(Error::new(ErrorKind::InvalidInput, e));
+                return Err(ParserError::Utf8(e));
             },
         }
 
-        let result = self.queue.borrow_mut().pop_front();
-        Ok(result)
+        match self.queue.borrow_mut().pop_front() {
+            None => Ok(None),
+            Some(result) =>
+                result.map(|pkt| Some(pkt)),
+        }
     }
 
     fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
@@ -262,7 +321,7 @@ impl Decoder for XMPPCodec {
 
 impl Encoder for XMPPCodec {
     type Item = Packet;
-    type Error = Error;
+    type Error = io::Error;
 
     fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
         let remaining = dst.capacity() - dst.len();
@@ -286,7 +345,7 @@ impl Encoder for XMPPCodec {
 
                 print!(">> {}", buf);
                 write!(dst, "{}", buf)
-                    .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
+                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
             },
             Packet::Stanza(stanza) => {
                 stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
@@ -294,7 +353,7 @@ impl Encoder for XMPPCodec {
                         // println!(">> {:?}", dst);
                         Ok(())
                     })
-                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
+                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
             },
             Packet::Text(text) => {
                 write_text(&text, dst)
@@ -302,7 +361,7 @@ impl Encoder for XMPPCodec {
                         // println!(">> {:?}", dst);
                         Ok(())
                     })
-                    .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
+                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
             },
             // TODO: Implement all
             _ => Ok(())