1//! XML stream parser for XMPP
2
3use crate::{ParseError, ParserError};
4use bytes::{BufMut, BytesMut};
5use minidom::Element;
6use quick_xml::Writer as EventWriter;
7use std;
8use std::cell::RefCell;
9use std::collections::vec_deque::VecDeque;
10use std::collections::HashMap;
11use std::default::Default;
12use std::fmt::Write;
13use std::io;
14use std::iter::FromIterator;
15use std::rc::Rc;
16use std::str::from_utf8;
17use tokio_codec::{Decoder, Encoder};
18use xml5ever::interface::Attribute;
19use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
20
21/// Anything that can be sent or received on an XMPP/XML stream
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum Packet {
24 /// `<stream:stream>` start tag
25 StreamStart(HashMap<String, String>),
26 /// A complete stanza or nonza
27 Stanza(Element),
28 /// Plain text (think whitespace keep-alive)
29 Text(String),
30 /// `</stream:stream>` closing tag
31 StreamEnd,
32}
33
34type QueueItem = Result<Packet, ParserError>;
35
36/// Parser state
37struct ParserSink {
38 // Ready stanzas, shared with XMPPCodec
39 queue: Rc<RefCell<VecDeque<QueueItem>>>,
40 // Parsing stack
41 stack: Vec<Element>,
42 ns_stack: Vec<HashMap<Option<String>, String>>,
43}
44
45impl ParserSink {
46 pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
47 ParserSink {
48 queue,
49 stack: vec![],
50 ns_stack: vec![],
51 }
52 }
53
54 fn push_queue(&self, pkt: Packet) {
55 self.queue.borrow_mut().push_back(Ok(pkt));
56 }
57
58 fn push_queue_error(&self, e: ParserError) {
59 self.queue.borrow_mut().push_back(Err(e));
60 }
61
62 /// Lookup XML namespace declaration for given prefix (or no prefix)
63 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
64 for nss in self.ns_stack.iter().rev() {
65 if let Some(ns) = nss.get(prefix) {
66 return Some(ns);
67 }
68 }
69
70 None
71 }
72
73 fn handle_start_tag(&mut self, tag: Tag) {
74 let mut nss = HashMap::new();
75 let is_prefix_xmlns = |attr: &Attribute| {
76 attr.name
77 .prefix
78 .as_ref()
79 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
80 .unwrap_or(false)
81 };
82 for attr in &tag.attrs {
83 match attr.name.local.as_ref() {
84 "xmlns" => {
85 nss.insert(None, attr.value.as_ref().to_owned());
86 }
87 prefix if is_prefix_xmlns(attr) => {
88 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
89 }
90 _ => (),
91 }
92 }
93 self.ns_stack.push(nss);
94
95 let el = {
96 let mut el_builder = Element::builder(tag.name.local.as_ref());
97 if let Some(el_ns) =
98 self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
99 {
100 el_builder = el_builder.ns(el_ns);
101 }
102 for attr in &tag.attrs {
103 match attr.name.local.as_ref() {
104 "xmlns" => (),
105 _ if is_prefix_xmlns(attr) => (),
106 _ => {
107 if let Some(ref prefix) = attr.name.prefix {
108 el_builder = el_builder.attr(format!("{}:{}", prefix, attr.name.local), attr.value.as_ref());
109 } else {
110 el_builder = el_builder.attr(attr.name.local.as_ref(), attr.value.as_ref());
111 }
112 }
113 }
114 }
115 el_builder.build()
116 };
117
118 if self.stack.is_empty() {
119 let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
120 (
121 attr.name.local.as_ref().to_owned(),
122 attr.value.as_ref().to_owned(),
123 )
124 }));
125 self.push_queue(Packet::StreamStart(attrs));
126 }
127
128 self.stack.push(el);
129 }
130
131 fn handle_end_tag(&mut self) {
132 let el = self.stack.pop().unwrap();
133 self.ns_stack.pop();
134
135 match self.stack.len() {
136 // </stream:stream>
137 0 => self.push_queue(Packet::StreamEnd),
138 // </stanza>
139 1 => self.push_queue(Packet::Stanza(el)),
140 len => {
141 let parent = &mut self.stack[len - 1];
142 parent.append_child(el);
143 }
144 }
145 }
146}
147
148impl TokenSink for ParserSink {
149 fn process_token(&mut self, token: Token) {
150 match token {
151 Token::TagToken(tag) => match tag.kind {
152 TagKind::StartTag => self.handle_start_tag(tag),
153 TagKind::EndTag => self.handle_end_tag(),
154 TagKind::EmptyTag => {
155 self.handle_start_tag(tag);
156 self.handle_end_tag();
157 }
158 TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
159 },
160 Token::CharacterTokens(tendril) => match self.stack.len() {
161 0 | 1 => self.push_queue(Packet::Text(tendril.into())),
162 len => {
163 let el = &mut self.stack[len - 1];
164 el.append_text_node(tendril);
165 }
166 },
167 Token::EOFToken => self.push_queue(Packet::StreamEnd),
168 Token::ParseError(s) => {
169 // println!("ParseError: {:?}", s);
170 self.push_queue_error(ParserError::Parse(ParseError(s)));
171 }
172 _ => (),
173 }
174 }
175
176 // fn end(&mut self) {
177 // }
178}
179
180/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
181pub struct XMPPCodec {
182 /// Outgoing
183 ns: Option<String>,
184 /// Incoming
185 parser: XmlTokenizer<ParserSink>,
186 /// For handling incoming truncated utf8
187 // TODO: optimize using tendrils?
188 buf: Vec<u8>,
189 /// Shared with ParserSink
190 queue: Rc<RefCell<VecDeque<QueueItem>>>,
191}
192
193impl XMPPCodec {
194 /// Constructor
195 pub fn new() -> Self {
196 let queue = Rc::new(RefCell::new(VecDeque::new()));
197 let sink = ParserSink::new(queue.clone());
198 // TODO: configure parser?
199 let parser = XmlTokenizer::new(sink, Default::default());
200 XMPPCodec {
201 ns: None,
202 parser,
203 queue,
204 buf: vec![],
205 }
206 }
207}
208
209impl Default for XMPPCodec {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215impl Decoder for XMPPCodec {
216 type Item = Packet;
217 type Error = ParserError;
218
219 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
220 let buf1: Box<AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
221 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
222 prefix.extend_from_slice(buf.take().as_ref());
223 Box::new(prefix)
224 } else {
225 Box::new(buf.take())
226 };
227 let buf1 = buf1.as_ref().as_ref();
228 match from_utf8(buf1) {
229 Ok(s) => {
230 if !s.is_empty() {
231 // println!("<< {}", s);
232 let tendril = FromIterator::from_iter(s.chars());
233 self.parser.feed(tendril);
234 }
235 }
236 // Remedies for truncated utf8
237 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
238 // Prepare all the valid data
239 let mut b = BytesMut::with_capacity(e.valid_up_to());
240 b.put(&buf1[0..e.valid_up_to()]);
241
242 // Retry
243 let result = self.decode(&mut b);
244
245 // Keep the tail back in
246 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
247
248 return result;
249 }
250 Err(e) => {
251 // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
252 return Err(ParserError::Utf8(e));
253 }
254 }
255
256 match self.queue.borrow_mut().pop_front() {
257 None => Ok(None),
258 Some(result) => result.map(|pkt| Some(pkt)),
259 }
260 }
261
262 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
263 self.decode(buf)
264 }
265}
266
267impl Encoder for XMPPCodec {
268 type Item = Packet;
269 type Error = io::Error;
270
271 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
272 let remaining = dst.capacity() - dst.len();
273 let max_stanza_size: usize = 2usize.pow(16);
274 if remaining < max_stanza_size {
275 dst.reserve(max_stanza_size - remaining);
276 }
277
278 match item {
279 Packet::StreamStart(start_attrs) => {
280 let mut buf = String::new();
281 write!(buf, "<stream:stream").unwrap();
282 for (name, value) in start_attrs {
283 write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).unwrap();
284 if name == "xmlns" {
285 self.ns = Some(value);
286 }
287 }
288 write!(buf, ">\n").unwrap();
289
290 print!(">> {}", buf);
291 write!(dst, "{}", buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
292 }
293 Packet::Stanza(stanza) => {
294 stanza
295 .write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
296 .and_then(|_| {
297 // println!(">> {:?}", dst);
298 Ok(())
299 })
300 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
301 }
302 Packet::Text(text) => {
303 write_text(&text, dst)
304 .and_then(|_| {
305 // println!(">> {:?}", dst);
306 Ok(())
307 })
308 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
309 }
310 // TODO: Implement all
311 _ => Ok(()),
312 }
313 }
314}
315
316/// Write XML-escaped text string
317pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
318 write!(writer, "{}", escape(text))
319}
320
/// Escapes the five XML special characters in `input`.
///
/// `&`, `<`, `>`, `'` and `"` are replaced by `&amp;`, `&lt;`, `&gt;`,
/// `&apos;` and `&quot;`; every other character is copied unchanged.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
337
338/// BytesMut impl only std::fmt::Write but not std::io::Write. The
339/// latter trait is required for minidom's
340/// `Element::write_to_inner()`.
341struct WriteBytes<'a> {
342 dst: &'a mut BytesMut,
343}
344
345impl<'a> WriteBytes<'a> {
346 fn new(dst: &'a mut BytesMut) -> Self {
347 WriteBytes { dst }
348 }
349}
350
351impl<'a> std::io::Write for WriteBytes<'a> {
352 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
353 self.dst.put_slice(buf);
354 Ok(buf.len())
355 }
356
357 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
358 Ok(())
359 }
360}
361
362#[cfg(test)]
363mod tests {
364 use super::*;
365 use bytes::BytesMut;
366
367 #[test]
368 fn test_stream_start() {
369 let mut c = XMPPCodec::new();
370 let mut b = BytesMut::with_capacity(1024);
371 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
372 let r = c.decode(&mut b);
373 assert!(match r {
374 Ok(Some(Packet::StreamStart(_))) => true,
375 _ => false,
376 });
377 }
378
379 #[test]
380 fn test_truncated_stanza() {
381 let mut c = XMPPCodec::new();
382 let mut b = BytesMut::with_capacity(1024);
383 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
384 let r = c.decode(&mut b);
385 assert!(match r {
386 Ok(Some(Packet::StreamStart(_))) => true,
387 _ => false,
388 });
389
390 b.clear();
391 b.put(r"<test>ß</test");
392 let r = c.decode(&mut b);
393 assert!(match r {
394 Ok(None) => true,
395 _ => false,
396 });
397
398 b.clear();
399 b.put(r">");
400 let r = c.decode(&mut b);
401 assert!(match r {
402 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
403 _ => false,
404 });
405 }
406
407 #[test]
408 fn test_truncated_utf8() {
409 let mut c = XMPPCodec::new();
410 let mut b = BytesMut::with_capacity(1024);
411 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
412 let r = c.decode(&mut b);
413 assert!(match r {
414 Ok(Some(Packet::StreamStart(_))) => true,
415 _ => false,
416 });
417
418 b.clear();
419 b.put(&b"<test>\xc3"[..]);
420 let r = c.decode(&mut b);
421 assert!(match r {
422 Ok(None) => true,
423 _ => false,
424 });
425
426 b.clear();
427 b.put(&b"\x9f</test>"[..]);
428 let r = c.decode(&mut b);
429 assert!(match r {
430 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
431 _ => false,
432 });
433 }
434
435 /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
436 #[test]
437 fn test_atrribute_prefix() {
438 let mut c = XMPPCodec::new();
439 let mut b = BytesMut::with_capacity(1024);
440 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
441 let r = c.decode(&mut b);
442 assert!(match r {
443 Ok(Some(Packet::StreamStart(_))) => true,
444 _ => false,
445 });
446
447 b.clear();
448 b.put(r"<status xml:lang='en'>Test status</status>");
449 let r = c.decode(&mut b);
450 assert!(match r {
451 Ok(Some(Packet::Stanza(ref el))) if el.name() == "status" && el.text() == "Test status" && el.attr("xml:lang").map_or(false, |a| a == "en") => true,
452 _ => false,
453 });
454
455 }
456
457 /// By default, encode() only get's a BytesMut that has 8kb space reserved.
458 #[test]
459 fn test_large_stanza() {
460 use futures::{Future, Sink};
461 use std::io::Cursor;
462 use tokio_codec::FramedWrite;
463 let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
464 let mut text = "".to_owned();
465 for _ in 0..2usize.pow(15) {
466 text = text + "A";
467 }
468 let stanza = Element::builder("message")
469 .append(Element::builder("body").append(&text).build())
470 .build();
471 let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
472 assert_eq!(
473 framed.get_ref().get_ref(),
474 &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
475 );
476 }
477}