xmpp_codec.rs

  1//! XML stream parser for XMPP
  2
  3use crate::{ParseError, ParserError};
  4use bytes::{BufMut, BytesMut};
  5use log::{debug, error};
  6use std;
  7use std::borrow::Cow;
  8use std::collections::vec_deque::VecDeque;
  9use std::collections::HashMap;
 10use std::default::Default;
 11use std::fmt::Write;
 12use std::io;
 13use std::iter::FromIterator;
 14use std::str::from_utf8;
 15use std::sync::Arc;
 16use std::sync::Mutex;
 17use tokio_util::codec::{Decoder, Encoder};
 18use xml5ever::buffer_queue::BufferQueue;
 19use xml5ever::interface::Attribute;
 20use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
 21use xmpp_parsers::Element;
 22
 23/// Anything that can be sent or received on an XMPP/XML stream
 24#[derive(Debug, Clone, PartialEq, Eq)]
 25pub enum Packet {
 26    /// `<stream:stream>` start tag
 27    StreamStart(HashMap<String, String>),
 28    /// A complete stanza or nonza
 29    Stanza(Element),
 30    /// Plain text (think whitespace keep-alive)
 31    Text(String),
 32    /// `</stream:stream>` closing tag
 33    StreamEnd,
 34}
 35
 36type QueueItem = Result<Packet, ParserError>;
 37
 38/// Parser state
 39struct ParserSink {
 40    // Ready stanzas, shared with XMPPCodec
 41    queue: Arc<Mutex<VecDeque<QueueItem>>>,
 42    // Parsing stack
 43    stack: Vec<Element>,
 44    ns_stack: Vec<HashMap<Option<String>, String>>,
 45}
 46
 47impl ParserSink {
 48    pub fn new(queue: Arc<Mutex<VecDeque<QueueItem>>>) -> Self {
 49        ParserSink {
 50            queue,
 51            stack: vec![],
 52            ns_stack: vec![],
 53        }
 54    }
 55
 56    fn push_queue(&self, pkt: Packet) {
 57        self.queue.lock().unwrap().push_back(Ok(pkt));
 58    }
 59
 60    fn push_queue_error(&self, e: ParserError) {
 61        self.queue.lock().unwrap().push_back(Err(e));
 62    }
 63
 64    /// Lookup XML namespace declaration for given prefix (or no prefix)
 65    fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
 66        for nss in self.ns_stack.iter().rev() {
 67            if let Some(ns) = nss.get(prefix) {
 68                return Some(ns);
 69            }
 70        }
 71
 72        None
 73    }
 74
 75    fn handle_start_tag(&mut self, tag: Tag) {
 76        let mut nss = HashMap::new();
 77        let is_prefix_xmlns = |attr: &Attribute| {
 78            attr.name
 79                .prefix
 80                .as_ref()
 81                .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
 82                .unwrap_or(false)
 83        };
 84        for attr in &tag.attrs {
 85            match attr.name.local.as_ref() {
 86                "xmlns" => {
 87                    nss.insert(None, attr.value.as_ref().to_owned());
 88                }
 89                prefix if is_prefix_xmlns(attr) => {
 90                    nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
 91                }
 92                _ => (),
 93            }
 94        }
 95        self.ns_stack.push(nss);
 96
 97        let el = {
 98            let el_ns = self
 99                .lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
100                .unwrap();
101            let mut el_builder = Element::builder(tag.name.local.as_ref(), el_ns);
102            for attr in &tag.attrs {
103                match attr.name.local.as_ref() {
104                    "xmlns" => (),
105                    _ if is_prefix_xmlns(attr) => (),
106                    _ => {
107                        let attr_name = if let Some(ref prefix) = attr.name.prefix {
108                            Cow::Owned(format!("{}:{}", prefix, attr.name.local))
109                        } else {
110                            Cow::Borrowed(attr.name.local.as_ref())
111                        };
112                        el_builder = el_builder.attr(attr_name, attr.value.as_ref());
113                    }
114                }
115            }
116            el_builder.build()
117        };
118
119        if self.stack.is_empty() {
120            let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
121                (
122                    attr.name.local.as_ref().to_owned(),
123                    attr.value.as_ref().to_owned(),
124                )
125            }));
126            self.push_queue(Packet::StreamStart(attrs));
127        }
128
129        self.stack.push(el);
130    }
131
132    fn handle_end_tag(&mut self) {
133        let el = self.stack.pop().unwrap();
134        self.ns_stack.pop();
135
136        match self.stack.len() {
137            // </stream:stream>
138            0 => self.push_queue(Packet::StreamEnd),
139            // </stanza>
140            1 => self.push_queue(Packet::Stanza(el)),
141            len => {
142                let parent = &mut self.stack[len - 1];
143                parent.append_child(el);
144            }
145        }
146    }
147}
148
149impl TokenSink for ParserSink {
150    fn process_token(&mut self, token: Token) {
151        match token {
152            Token::TagToken(tag) => match tag.kind {
153                TagKind::StartTag => self.handle_start_tag(tag),
154                TagKind::EndTag => self.handle_end_tag(),
155                TagKind::EmptyTag => {
156                    self.handle_start_tag(tag);
157                    self.handle_end_tag();
158                }
159                TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
160            },
161            Token::CharacterTokens(tendril) => match self.stack.len() {
162                0 | 1 => self.push_queue(Packet::Text(tendril.into())),
163                len => {
164                    let el = &mut self.stack[len - 1];
165                    el.append_text_node(tendril);
166                }
167            },
168            Token::EOFToken => self.push_queue(Packet::StreamEnd),
169            Token::ParseError(s) => {
170                self.push_queue_error(ParserError::Parse(ParseError(s)));
171            }
172            _ => (),
173        }
174    }
175
176    // fn end(&mut self) {
177    // }
178}
179
180/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
181pub struct XMPPCodec {
182    /// Outgoing
183    ns: Option<String>,
184    /// Incoming
185    parser: XmlTokenizer<ParserSink>,
186    /// For handling incoming truncated utf8
187    // TODO: optimize using  tendrils?
188    buf: Vec<u8>,
189    /// Shared with ParserSink
190    queue: Arc<Mutex<VecDeque<QueueItem>>>,
191}
192
193impl XMPPCodec {
194    /// Constructor
195    pub fn new() -> Self {
196        let queue = Arc::new(Mutex::new(VecDeque::new()));
197        let sink = ParserSink::new(queue.clone());
198        // TODO: configure parser?
199        let parser = XmlTokenizer::new(sink, Default::default());
200        XMPPCodec {
201            ns: None,
202            parser,
203            queue,
204            buf: vec![],
205        }
206    }
207}
208
209impl Default for XMPPCodec {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
215impl Decoder for XMPPCodec {
216    type Item = Packet;
217    type Error = ParserError;
218
219    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
220        let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
221            let mut prefix = std::mem::replace(&mut self.buf, vec![]);
222            prefix.extend_from_slice(&buf.split_to(buf.len()));
223            Box::new(prefix)
224        } else {
225            Box::new(buf.split_to(buf.len()))
226        };
227        let buf1 = buf1.as_ref().as_ref();
228        match from_utf8(buf1) {
229            Ok(s) => {
230                debug!("<< {:?}", s);
231                if !s.is_empty() {
232                    let mut buffer_queue = BufferQueue::new();
233                    let tendril = FromIterator::from_iter(s.chars());
234                    buffer_queue.push_back(tendril);
235                    self.parser.feed(&mut buffer_queue);
236                }
237            }
238            // Remedies for truncated utf8
239            Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
240                // Prepare all the valid data
241                let mut b = BytesMut::with_capacity(e.valid_up_to());
242                b.put(&buf1[0..e.valid_up_to()]);
243
244                // Retry
245                let result = self.decode(&mut b);
246
247                // Keep the tail back in
248                self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
249
250                return result;
251            }
252            Err(e) => {
253                error!(
254                    "error {} at {}/{} in {:?}",
255                    e,
256                    e.valid_up_to(),
257                    buf1.len(),
258                    buf1
259                );
260                return Err(ParserError::Utf8(e));
261            }
262        }
263
264        match self.queue.lock().unwrap().pop_front() {
265            None => Ok(None),
266            Some(result) => result.map(|pkt| Some(pkt)),
267        }
268    }
269
270    fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
271        self.decode(buf)
272    }
273}
274
275impl Encoder<Packet> for XMPPCodec {
276    type Error = io::Error;
277
278    fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
279        let remaining = dst.capacity() - dst.len();
280        let max_stanza_size: usize = 2usize.pow(16);
281        if remaining < max_stanza_size {
282            dst.reserve(max_stanza_size - remaining);
283        }
284
285        fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
286            io::Error::new(io::ErrorKind::InvalidInput, e)
287        }
288
289        match item {
290            Packet::StreamStart(start_attrs) => {
291                let mut buf = String::new();
292                write!(buf, "<stream:stream").map_err(to_io_err)?;
293                for (name, value) in start_attrs {
294                    write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
295                    if name == "xmlns" {
296                        self.ns = Some(value);
297                    }
298                }
299                write!(buf, ">\n").map_err(to_io_err)?;
300
301                debug!(">> {:?}", buf);
302                write!(dst, "{}", buf).map_err(to_io_err)
303            }
304            Packet::Stanza(stanza) => stanza
305                .write_to(&mut WriteBytes::new(dst))
306                .and_then(|_| {
307                    debug!(">> {:?}", dst);
308                    Ok(())
309                })
310                .map_err(|e| to_io_err(format!("{}", e))),
311            Packet::Text(text) => write_text(&text, dst)
312                .and_then(|_| {
313                    debug!(">> {:?}", dst);
314                    Ok(())
315                })
316                .map_err(to_io_err),
317            Packet::StreamEnd => write!(dst, "</stream:stream>\n").map_err(to_io_err),
318        }
319    }
320}
321
/// Write `text` to `writer` with the XML special characters escaped.
pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
    let escaped = escape(text);
    writer.write_str(&escaped)
}

/// Return a copy of `input` in which `&`, `<`, `>`, `'` and `"` are replaced by
/// their predefined XML entities; every other character is kept as-is.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut out, ch| {
            let entity = match ch {
                '&' => "&amp;",
                '<' => "&lt;",
                '>' => "&gt;",
                '\'' => "&apos;",
                '"' => "&quot;",
                plain => {
                    out.push(plain);
                    return out;
                }
            };
            out.push_str(entity);
            out
        })
}
343
344/// BytesMut impl only std::fmt::Write but not std::io::Write. The
345/// latter trait is required for minidom's
346/// `Element::write_to_inner()`.
347struct WriteBytes<'a> {
348    dst: &'a mut BytesMut,
349}
350
351impl<'a> WriteBytes<'a> {
352    fn new(dst: &'a mut BytesMut) -> Self {
353        WriteBytes { dst }
354    }
355}
356
357impl<'a> std::io::Write for WriteBytes<'a> {
358    fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
359        self.dst.put_slice(buf);
360        Ok(buf.len())
361    }
362
363    fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
364        Ok(())
365    }
366}
367
368#[cfg(test)]
369mod tests {
370    use super::*;
371    use bytes::BytesMut;
372
373    #[test]
374    fn test_stream_start() {
375        let mut c = XMPPCodec::new();
376        let mut b = BytesMut::with_capacity(1024);
377        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
378        let r = c.decode(&mut b);
379        assert!(match r {
380            Ok(Some(Packet::StreamStart(_))) => true,
381            _ => false,
382        });
383    }
384
385    #[test]
386    fn test_stream_end() {
387        let mut c = XMPPCodec::new();
388        let mut b = BytesMut::with_capacity(1024);
389        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
390        let r = c.decode(&mut b);
391        assert!(match r {
392            Ok(Some(Packet::StreamStart(_))) => true,
393            _ => false,
394        });
395        b.clear();
396        b.put_slice(b"</stream:stream>");
397        let r = c.decode(&mut b);
398        assert!(match r {
399            Ok(Some(Packet::StreamEnd)) => true,
400            _ => false,
401        });
402    }
403
404    #[test]
405    fn test_truncated_stanza() {
406        let mut c = XMPPCodec::new();
407        let mut b = BytesMut::with_capacity(1024);
408        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
409        let r = c.decode(&mut b);
410        assert!(match r {
411            Ok(Some(Packet::StreamStart(_))) => true,
412            _ => false,
413        });
414
415        b.clear();
416        b.put_slice("<test>ß</test".as_bytes());
417        let r = c.decode(&mut b);
418        assert!(match r {
419            Ok(None) => true,
420            _ => false,
421        });
422
423        b.clear();
424        b.put_slice(b">");
425        let r = c.decode(&mut b);
426        assert!(match r {
427            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
428            _ => false,
429        });
430    }
431
432    #[test]
433    fn test_truncated_utf8() {
434        let mut c = XMPPCodec::new();
435        let mut b = BytesMut::with_capacity(1024);
436        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
437        let r = c.decode(&mut b);
438        assert!(match r {
439            Ok(Some(Packet::StreamStart(_))) => true,
440            _ => false,
441        });
442
443        b.clear();
444        b.put(&b"<test>\xc3"[..]);
445        let r = c.decode(&mut b);
446        assert!(match r {
447            Ok(None) => true,
448            _ => false,
449        });
450
451        b.clear();
452        b.put(&b"\x9f</test>"[..]);
453        let r = c.decode(&mut b);
454        assert!(match r {
455            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
456            _ => false,
457        });
458    }
459
460    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
461    #[test]
462    fn test_atrribute_prefix() {
463        let mut c = XMPPCodec::new();
464        let mut b = BytesMut::with_capacity(1024);
465        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
466        let r = c.decode(&mut b);
467        assert!(match r {
468            Ok(Some(Packet::StreamStart(_))) => true,
469            _ => false,
470        });
471
472        b.clear();
473        b.put_slice(b"<status xml:lang='en'>Test status</status>");
474        let r = c.decode(&mut b);
475        assert!(match r {
476            Ok(Some(Packet::Stanza(ref el)))
477                if el.name() == "status"
478                    && el.text() == "Test status"
479                    && el.attr("xml:lang").map_or(false, |a| a == "en") =>
480                true,
481            _ => false,
482        });
483    }
484
485    /// By default, encode() only get's a BytesMut that has 8kb space reserved.
486    #[test]
487    fn test_large_stanza() {
488        use futures::{executor::block_on, sink::SinkExt};
489        use std::io::Cursor;
490        use tokio_util::codec::FramedWrite;
491        let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
492        let mut text = "".to_owned();
493        for _ in 0..2usize.pow(15) {
494            text = text + "A";
495        }
496        let stanza = Element::builder("message", "jabber:client")
497            .append(
498                Element::builder("body", "jabber:client")
499                    .append(text.as_ref())
500                    .build(),
501            )
502            .build();
503        block_on(framed.send(Packet::Stanza(stanza))).expect("send");
504        assert_eq!(
505            framed.get_ref().get_ref(),
506            &("<message xmlns=\"jabber:client\"><body>".to_owned() + &text + "</body></message>")
507                .as_bytes()
508        );
509    }
510
511    #[test]
512    fn test_cut_out_stanza() {
513        let mut c = XMPPCodec::new();
514        let mut b = BytesMut::with_capacity(1024);
515        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
516        let r = c.decode(&mut b);
517        assert!(match r {
518            Ok(Some(Packet::StreamStart(_))) => true,
519            _ => false,
520        });
521
522        b.clear();
523        b.put_slice(b"<message ");
524        b.put_slice(b"type='chat'><body>Foo</body></message>");
525        let r = c.decode(&mut b);
526        assert!(match r {
527            Ok(Some(Packet::Stanza(_))) => true,
528            _ => false,
529        });
530    }
531}