1//! XML stream parser for XMPP
2
3use std;
4use std::default::Default;
5use std::iter::FromIterator;
6use std::cell::RefCell;
7use std::rc::Rc;
8use std::fmt::Write;
9use std::str::{from_utf8, Utf8Error};
10use std::io;
11use std::collections::HashMap;
12use std::collections::vec_deque::VecDeque;
13use std::error::Error as StdError;
14use std::fmt;
15use std::borrow::Cow;
16use tokio_codec::{Encoder, Decoder};
17use minidom::Element;
18use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
19use xml5ever::interface::Attribute;
20use bytes::{BytesMut, BufMut};
21use quick_xml::Writer as EventWriter;
22
23/// Anything that can be sent or received on an XMPP/XML stream
24#[derive(Debug)]
25pub enum Packet {
26 /// `<stream:stream>` start tag
27 StreamStart(HashMap<String, String>),
28 /// A complete stanza or nonza
29 Stanza(Element),
30 /// Plain text (think whitespace keep-alive)
31 Text(String),
32 /// `</stream:stream>` closing tag
33 StreamEnd,
34}
35
/// Reasons why parsing the incoming stream can fail
#[derive(Debug)]
pub enum ParserError {
    /// The received bytes are not valid UTF-8
    Utf8(Utf8Error),
    /// XML parse error, carrying the message that describes it
    Parse(Cow<'static, str>),
    /// An illegal short tag `</>` was encountered
    ShortTag,
    /// I/O failure; this variant is required by `impl Decoder`
    IO(io::Error),
}
48
49impl From<io::Error> for ParserError {
50 fn from(e: io::Error) -> Self {
51 ParserError::IO(e)
52 }
53}
54
55impl StdError for ParserError {
56 fn description(&self) -> &str {
57 match *self {
58 ParserError::Utf8(ref ue) => ue.description(),
59 ParserError::Parse(ref pe) => pe,
60 ParserError::ShortTag => "short tag",
61 ParserError::IO(ref ie) => ie.description(),
62 }
63 }
64
65 fn cause(&self) -> Option<&StdError> {
66 match *self {
67 ParserError::Utf8(ref ue) => ue.cause(),
68 ParserError::IO(ref ie) => ie.cause(),
69 _ => None,
70 }
71 }
72}
73
74impl fmt::Display for ParserError {
75 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
76 match *self {
77 ParserError::Utf8(ref ue) => write!(f, "{}", ue),
78 ParserError::Parse(ref pe) => write!(f, "{}", pe),
79 ParserError::ShortTag => write!(f, "Short tag"),
80 ParserError::IO(ref ie) => write!(f, "{}", ie),
81 }
82 }
83}
84
85type QueueItem = Result<Packet, ParserError>;
86
87/// Parser state
88struct ParserSink {
89 // Ready stanzas, shared with XMPPCodec
90 queue: Rc<RefCell<VecDeque<QueueItem>>>,
91 // Parsing stack
92 stack: Vec<Element>,
93 ns_stack: Vec<HashMap<Option<String>, String>>,
94}
95
96impl ParserSink {
97 pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
98 ParserSink {
99 queue,
100 stack: vec![],
101 ns_stack: vec![],
102 }
103 }
104
105 fn push_queue(&self, pkt: Packet) {
106 self.queue.borrow_mut().push_back(Ok(pkt));
107 }
108
109 fn push_queue_error(&self, e: ParserError) {
110 self.queue.borrow_mut().push_back(Err(e));
111 }
112
113 /// Lookup XML namespace declaration for given prefix (or no prefix)
114 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
115 for nss in self.ns_stack.iter().rev() {
116 if let Some(ns) = nss.get(prefix) {
117 return Some(ns);
118 }
119 }
120
121 None
122 }
123
124 fn handle_start_tag(&mut self, tag: Tag) {
125 let mut nss = HashMap::new();
126 let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
127 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
128 .unwrap_or(false);
129 for attr in &tag.attrs {
130 match attr.name.local.as_ref() {
131 "xmlns" => {
132 nss.insert(None, attr.value.as_ref().to_owned());
133 },
134 prefix if is_prefix_xmlns(attr) => {
135 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
136 },
137 _ => (),
138 }
139 }
140 self.ns_stack.push(nss);
141
142 let el = {
143 let mut el_builder = Element::builder(tag.name.local.as_ref());
144 if let Some(el_ns) = self.lookup_ns(
145 &tag.name.prefix.map(|prefix|
146 prefix.as_ref().to_owned())
147 ) {
148 el_builder = el_builder.ns(el_ns);
149 }
150 for attr in &tag.attrs {
151 match attr.name.local.as_ref() {
152 "xmlns" => (),
153 _ if is_prefix_xmlns(attr) => (),
154 _ => {
155 el_builder = el_builder.attr(
156 attr.name.local.as_ref(),
157 attr.value.as_ref()
158 );
159 },
160 }
161 }
162 el_builder.build()
163 };
164
165 if self.stack.is_empty() {
166 let attrs = HashMap::from_iter(
167 tag.attrs.iter()
168 .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
169 );
170 self.push_queue(Packet::StreamStart(attrs));
171 }
172
173 self.stack.push(el);
174 }
175
176 fn handle_end_tag(&mut self) {
177 let el = self.stack.pop().unwrap();
178 self.ns_stack.pop();
179
180 match self.stack.len() {
181 // </stream:stream>
182 0 =>
183 self.push_queue(Packet::StreamEnd),
184 // </stanza>
185 1 =>
186 self.push_queue(Packet::Stanza(el)),
187 len => {
188 let parent = &mut self.stack[len - 1];
189 parent.append_child(el);
190 },
191 }
192 }
193}
194
195impl TokenSink for ParserSink {
196 fn process_token(&mut self, token: Token) {
197 match token {
198 Token::TagToken(tag) => match tag.kind {
199 TagKind::StartTag =>
200 self.handle_start_tag(tag),
201 TagKind::EndTag =>
202 self.handle_end_tag(),
203 TagKind::EmptyTag => {
204 self.handle_start_tag(tag);
205 self.handle_end_tag();
206 },
207 TagKind::ShortTag =>
208 self.push_queue_error(ParserError::ShortTag),
209 },
210 Token::CharacterTokens(tendril) =>
211 match self.stack.len() {
212 0 | 1 =>
213 self.push_queue(Packet::Text(tendril.into())),
214 len => {
215 let el = &mut self.stack[len - 1];
216 el.append_text_node(tendril);
217 },
218 },
219 Token::EOFToken =>
220 self.push_queue(Packet::StreamEnd),
221 Token::ParseError(s) => {
222 // println!("ParseError: {:?}", s);
223 self.push_queue_error(ParserError::Parse(s));
224 },
225 _ => (),
226 }
227 }
228
229 // fn end(&mut self) {
230 // }
231}
232
233/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
234pub struct XMPPCodec {
235 /// Outgoing
236 ns: Option<String>,
237 /// Incoming
238 parser: XmlTokenizer<ParserSink>,
239 /// For handling incoming truncated utf8
240 // TODO: optimize using tendrils?
241 buf: Vec<u8>,
242 /// Shared with ParserSink
243 queue: Rc<RefCell<VecDeque<QueueItem>>>,
244}
245
246impl XMPPCodec {
247 /// Constructor
248 pub fn new() -> Self {
249 let queue = Rc::new(RefCell::new(VecDeque::new()));
250 let sink = ParserSink::new(queue.clone());
251 // TODO: configure parser?
252 let parser = XmlTokenizer::new(sink, Default::default());
253 XMPPCodec {
254 ns: None,
255 parser,
256 queue,
257 buf: vec![],
258 }
259 }
260}
261
262impl Default for XMPPCodec {
263 fn default() -> Self {
264 Self::new()
265 }
266}
267
268impl Decoder for XMPPCodec {
269 type Item = Packet;
270 type Error = ParserError;
271
272 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
273 let buf1: Box<AsRef<[u8]>> =
274 if ! self.buf.is_empty() && ! buf.is_empty() {
275 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
276 prefix.extend_from_slice(buf.take().as_ref());
277 Box::new(prefix)
278 } else {
279 Box::new(buf.take())
280 };
281 let buf1 = buf1.as_ref().as_ref();
282 match from_utf8(buf1) {
283 Ok(s) => {
284 if ! s.is_empty() {
285 // println!("<< {}", s);
286 let tendril = FromIterator::from_iter(s.chars());
287 self.parser.feed(tendril);
288 }
289 },
290 // Remedies for truncated utf8
291 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
292 // Prepare all the valid data
293 let mut b = BytesMut::with_capacity(e.valid_up_to());
294 b.put(&buf1[0..e.valid_up_to()]);
295
296 // Retry
297 let result = self.decode(&mut b);
298
299 // Keep the tail back in
300 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
301
302 return result;
303 },
304 Err(e) => {
305 // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
306 return Err(ParserError::Utf8(e));
307 },
308 }
309
310 match self.queue.borrow_mut().pop_front() {
311 None => Ok(None),
312 Some(result) =>
313 result.map(|pkt| Some(pkt)),
314 }
315 }
316
317 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
318 self.decode(buf)
319 }
320}
321
322impl Encoder for XMPPCodec {
323 type Item = Packet;
324 type Error = io::Error;
325
326 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
327 let remaining = dst.capacity() - dst.len();
328 let max_stanza_size: usize = 2usize.pow(16);
329 if remaining < max_stanza_size {
330 dst.reserve(max_stanza_size - remaining);
331 }
332
333 match item {
334 Packet::StreamStart(start_attrs) => {
335 let mut buf = String::new();
336 write!(buf, "<stream:stream").unwrap();
337 for (name, value) in start_attrs {
338 write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
339 .unwrap();
340 if name == "xmlns" {
341 self.ns = Some(value);
342 }
343 }
344 write!(buf, ">\n").unwrap();
345
346 print!(">> {}", buf);
347 write!(dst, "{}", buf)
348 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
349 },
350 Packet::Stanza(stanza) => {
351 stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
352 .and_then(|_| {
353 // println!(">> {:?}", dst);
354 Ok(())
355 })
356 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
357 },
358 Packet::Text(text) => {
359 write_text(&text, dst)
360 .and_then(|_| {
361 // println!(">> {:?}", dst);
362 Ok(())
363 })
364 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
365 },
366 // TODO: Implement all
367 _ => Ok(())
368 }
369 }
370}
371
372/// Write XML-escaped text string
373pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
374 write!(writer, "{}", escape(text))
375}
376
/// Replaces the five characters that are special in XML (`&`, `<`, `>`,
/// `'` and `"`) by their predefined entities; everything else is copied
/// unchanged.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    // The result is at least as long as the input
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
393
394/// BytesMut impl only std::fmt::Write but not std::io::Write. The
395/// latter trait is required for minidom's
396/// `Element::write_to_inner()`.
397struct WriteBytes<'a> {
398 dst: &'a mut BytesMut,
399}
400
401impl<'a> WriteBytes<'a> {
402 fn new(dst: &'a mut BytesMut) -> Self {
403 WriteBytes { dst }
404 }
405}
406
407impl<'a> std::io::Write for WriteBytes<'a> {
408 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
409 self.dst.put_slice(buf);
410 Ok(buf.len())
411 }
412
413 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
414 Ok(())
415 }
416}
417
418
419#[cfg(test)]
420mod tests {
421 use super::*;
422 use bytes::BytesMut;
423
424 #[test]
425 fn test_stream_start() {
426 let mut c = XMPPCodec::new();
427 let mut b = BytesMut::with_capacity(1024);
428 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
429 let r = c.decode(&mut b);
430 assert!(match r {
431 Ok(Some(Packet::StreamStart(_))) => true,
432 _ => false,
433 });
434 }
435
436 #[test]
437 fn test_truncated_stanza() {
438 let mut c = XMPPCodec::new();
439 let mut b = BytesMut::with_capacity(1024);
440 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
441 let r = c.decode(&mut b);
442 assert!(match r {
443 Ok(Some(Packet::StreamStart(_))) => true,
444 _ => false,
445 });
446
447 b.clear();
448 b.put(r"<test>ß</test");
449 let r = c.decode(&mut b);
450 assert!(match r {
451 Ok(None) => true,
452 _ => false,
453 });
454
455 b.clear();
456 b.put(r">");
457 let r = c.decode(&mut b);
458 assert!(match r {
459 Ok(Some(Packet::Stanza(ref el)))
460 if el.name() == "test"
461 && el.text() == "ß"
462 => true,
463 _ => false,
464 });
465 }
466
467 #[test]
468 fn test_truncated_utf8() {
469 let mut c = XMPPCodec::new();
470 let mut b = BytesMut::with_capacity(1024);
471 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
472 let r = c.decode(&mut b);
473 assert!(match r {
474 Ok(Some(Packet::StreamStart(_))) => true,
475 _ => false,
476 });
477
478 b.clear();
479 b.put(&b"<test>\xc3"[..]);
480 let r = c.decode(&mut b);
481 assert!(match r {
482 Ok(None) => true,
483 _ => false,
484 });
485
486 b.clear();
487 b.put(&b"\x9f</test>"[..]);
488 let r = c.decode(&mut b);
489 assert!(match r {
490 Ok(Some(Packet::Stanza(ref el)))
491 if el.name() == "test"
492 && el.text() == "ß"
493 => true,
494 _ => false,
495 });
496 }
497
498 /// By default, encode() only get's a BytesMut that has 8kb space reserved.
499 #[test]
500 fn test_large_stanza() {
501 use std::io::Cursor;
502 use futures::{Future, Sink};
503 use tokio_codec::FramedWrite;
504 let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
505 let mut text = "".to_owned();
506 for _ in 0..2usize.pow(15) {
507 text = text + "A";
508 }
509 let stanza = Element::builder("message")
510 .append(
511 Element::builder("body")
512 .append(&text)
513 .build()
514 )
515 .build();
516 let framed = framed.send(Packet::Stanza(stanza))
517 .wait()
518 .expect("send");
519 assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
520 }
521}