1//! XML stream parser for XMPP
2
3use crate::{ParseError, ParserError};
4use bytes::{BufMut, BytesMut};
5use log::debug;
6use std;
7use std::borrow::Cow;
8use std::collections::vec_deque::VecDeque;
9use std::collections::HashMap;
10use std::default::Default;
11use std::fmt::Write;
12use std::io;
13use std::iter::FromIterator;
14use std::str::from_utf8;
15use std::sync::Arc;
16use std::sync::Mutex;
17use tokio_util::codec::{Decoder, Encoder};
18use xml5ever::buffer_queue::BufferQueue;
19use xml5ever::interface::Attribute;
20use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
21use xmpp_parsers::Element;
22
/// Anything that can be sent or received on an XMPP/XML stream
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// `<stream:stream>` start tag with its attributes.
    ///
    /// When produced by the decoder, attributes are keyed by their *local*
    /// name only (e.g. `xmlns:stream` arrives as `stream`); when encoding,
    /// each key is written out verbatim as the attribute name.
    StreamStart(HashMap<String, String>),
    /// A complete stanza or nonza (a direct child of `<stream:stream>`)
    Stanza(Element),
    /// Plain text outside of any stanza (think whitespace keep-alive)
    Text(String),
    /// `</stream:stream>` closing tag; the decoder also emits this at end of input
    StreamEnd,
}

/// One entry of the decoder's output queue: a parsed packet or a parse error.
type QueueItem = Result<Packet, ParserError>;
37
/// Parser state
///
/// Receives tokens from the xml5ever tokenizer, assembles `Element`s and
/// pushes finished packets (or errors) onto a queue shared with `XMPPCodec`.
struct ParserSink {
    // Ready stanzas, shared with XMPPCodec
    queue: Arc<Mutex<VecDeque<QueueItem>>>,
    // Parsing stack: currently open elements, outermost (`<stream:stream>`) first
    stack: Vec<Element>,
    // Namespace declarations, one map per open element (pushed/popped together
    // with `stack`). Key `None` is the default namespace (`xmlns`), `Some(p)`
    // is a prefixed declaration (`xmlns:p`).
    ns_stack: Vec<HashMap<Option<String>, String>>,
}
46
47impl ParserSink {
48 pub fn new(queue: Arc<Mutex<VecDeque<QueueItem>>>) -> Self {
49 ParserSink {
50 queue,
51 stack: vec![],
52 ns_stack: vec![],
53 }
54 }
55
56 fn push_queue(&self, pkt: Packet) {
57 self.queue.lock().unwrap().push_back(Ok(pkt));
58 }
59
60 fn push_queue_error(&self, e: ParserError) {
61 self.queue.lock().unwrap().push_back(Err(e));
62 }
63
64 /// Lookup XML namespace declaration for given prefix (or no prefix)
65 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
66 for nss in self.ns_stack.iter().rev() {
67 if let Some(ns) = nss.get(prefix) {
68 return Some(ns);
69 }
70 }
71
72 None
73 }
74
75 fn handle_start_tag(&mut self, tag: Tag) {
76 let mut nss = HashMap::new();
77 let is_prefix_xmlns = |attr: &Attribute| {
78 attr.name
79 .prefix
80 .as_ref()
81 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
82 .unwrap_or(false)
83 };
84 for attr in &tag.attrs {
85 match attr.name.local.as_ref() {
86 "xmlns" => {
87 nss.insert(None, attr.value.as_ref().to_owned());
88 }
89 prefix if is_prefix_xmlns(attr) => {
90 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
91 }
92 _ => (),
93 }
94 }
95 self.ns_stack.push(nss);
96
97 let el = {
98 let mut el_builder = Element::builder(tag.name.local.as_ref());
99 if let Some(el_ns) =
100 self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
101 {
102 el_builder = el_builder.ns(el_ns);
103 }
104 for attr in &tag.attrs {
105 match attr.name.local.as_ref() {
106 "xmlns" => (),
107 _ if is_prefix_xmlns(attr) => (),
108 _ => {
109 let attr_name = if let Some(ref prefix) = attr.name.prefix {
110 Cow::Owned(format!("{}:{}", prefix, attr.name.local))
111 } else {
112 Cow::Borrowed(attr.name.local.as_ref())
113 };
114 el_builder = el_builder.attr(attr_name, attr.value.as_ref());
115 }
116 }
117 }
118 el_builder.build()
119 };
120
121 if self.stack.is_empty() {
122 let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
123 (
124 attr.name.local.as_ref().to_owned(),
125 attr.value.as_ref().to_owned(),
126 )
127 }));
128 self.push_queue(Packet::StreamStart(attrs));
129 }
130
131 self.stack.push(el);
132 }
133
134 fn handle_end_tag(&mut self) {
135 let el = self.stack.pop().unwrap();
136 self.ns_stack.pop();
137
138 match self.stack.len() {
139 // </stream:stream>
140 0 => self.push_queue(Packet::StreamEnd),
141 // </stanza>
142 1 => self.push_queue(Packet::Stanza(el)),
143 len => {
144 let parent = &mut self.stack[len - 1];
145 parent.append_child(el);
146 }
147 }
148 }
149}
150
151impl TokenSink for ParserSink {
152 fn process_token(&mut self, token: Token) {
153 match token {
154 Token::TagToken(tag) => match tag.kind {
155 TagKind::StartTag => self.handle_start_tag(tag),
156 TagKind::EndTag => self.handle_end_tag(),
157 TagKind::EmptyTag => {
158 self.handle_start_tag(tag);
159 self.handle_end_tag();
160 }
161 TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
162 },
163 Token::CharacterTokens(tendril) => match self.stack.len() {
164 0 | 1 => self.push_queue(Packet::Text(tendril.into())),
165 len => {
166 let el = &mut self.stack[len - 1];
167 el.append_text_node(tendril);
168 }
169 },
170 Token::EOFToken => self.push_queue(Packet::StreamEnd),
171 Token::ParseError(s) => {
172 self.push_queue_error(ParserError::Parse(ParseError(s)));
173 }
174 _ => (),
175 }
176 }
177
178 // fn end(&mut self) {
179 // }
180}
181
/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
pub struct XMPPCodec {
    /// Outgoing: the `xmlns` value of the last `StreamStart` that was encoded
    ns: Option<String>,
    /// Incoming: XML tokenizer that feeds a `ParserSink`
    parser: XmlTokenizer<ParserSink>,
    /// For handling incoming truncated utf8: trailing bytes of an incomplete
    /// UTF-8 sequence, prepended to the next chunk passed to `decode`
    // TODO: optimize using tendrils?
    buf: Vec<u8>,
    /// Shared with ParserSink: parsed packets/errors, popped one per `decode` call
    queue: Arc<Mutex<VecDeque<QueueItem>>>,
}
194
195impl XMPPCodec {
196 /// Constructor
197 pub fn new() -> Self {
198 let queue = Arc::new(Mutex::new(VecDeque::new()));
199 let sink = ParserSink::new(queue.clone());
200 // TODO: configure parser?
201 let parser = XmlTokenizer::new(sink, Default::default());
202 XMPPCodec {
203 ns: None,
204 parser,
205 queue,
206 buf: vec![],
207 }
208 }
209}
210
211impl Default for XMPPCodec {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
217impl Decoder for XMPPCodec {
218 type Item = Packet;
219 type Error = ParserError;
220
221 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
222 let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
223 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
224 prefix.extend_from_slice(&buf.split_to(buf.len()));
225 Box::new(prefix)
226 } else {
227 Box::new(buf.split_to(buf.len()))
228 };
229 let buf1 = buf1.as_ref().as_ref();
230 match from_utf8(buf1) {
231 Ok(s) => {
232 debug!("<< {:?}", s);
233 if !s.is_empty() {
234 let mut buffer_queue = BufferQueue::new();
235 let tendril = FromIterator::from_iter(s.chars());
236 buffer_queue.push_back(tendril);
237 self.parser.feed(&mut buffer_queue);
238 }
239 }
240 // Remedies for truncated utf8
241 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
242 // Prepare all the valid data
243 let mut b = BytesMut::with_capacity(e.valid_up_to());
244 b.put(&buf1[0..e.valid_up_to()]);
245
246 // Retry
247 let result = self.decode(&mut b);
248
249 // Keep the tail back in
250 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
251
252 return result;
253 }
254 Err(e) => {
255 // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
256 return Err(ParserError::Utf8(e));
257 }
258 }
259
260 match self.queue.lock().unwrap().pop_front() {
261 None => Ok(None),
262 Some(result) => result.map(|pkt| Some(pkt)),
263 }
264 }
265
266 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
267 self.decode(buf)
268 }
269}
270
271impl Encoder for XMPPCodec {
272 type Item = Packet;
273 type Error = io::Error;
274
275 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
276 let remaining = dst.capacity() - dst.len();
277 let max_stanza_size: usize = 2usize.pow(16);
278 if remaining < max_stanza_size {
279 dst.reserve(max_stanza_size - remaining);
280 }
281
282 fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
283 io::Error::new(io::ErrorKind::InvalidInput, e)
284 }
285
286 match item {
287 Packet::StreamStart(start_attrs) => {
288 let mut buf = String::new();
289 write!(buf, "<stream:stream").map_err(to_io_err)?;
290 for (name, value) in start_attrs {
291 write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
292 if name == "xmlns" {
293 self.ns = Some(value);
294 }
295 }
296 write!(buf, ">\n").map_err(to_io_err)?;
297
298 debug!(">> {:?}", buf);
299 write!(dst, "{}", buf).map_err(to_io_err)
300 }
301 Packet::Stanza(stanza) => stanza
302 .write_to(&mut WriteBytes::new(dst))
303 .and_then(|_| {
304 debug!(">> {:?}", dst);
305 Ok(())
306 })
307 .map_err(|e| to_io_err(format!("{}", e))),
308 Packet::Text(text) => write_text(&text, dst)
309 .and_then(|_| {
310 debug!(">> {:?}", dst);
311 Ok(())
312 })
313 .map_err(to_io_err),
314 Packet::StreamEnd => write!(dst, "</stream:stream>\n").map_err(to_io_err),
315 }
316 }
317}
318
319/// Write XML-escaped text string
320pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
321 write!(writer, "{}", escape(text))
322}
323
/// Escape the five XML special characters in `input`.
///
/// `&`, `<`, `>`, `'` and `"` are replaced by the predefined entity
/// references `&amp;`, `&lt;`, `&gt;`, `&apos;` and `&quot;`. The result is
/// safe to embed in character data and in single- or double-quoted attribute
/// values. All other characters, including non-ASCII ones, are copied
/// unchanged.
///
/// Originally copied from `RustyXML`.
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            other => result.push(other),
        }
    }
    result
}
340
341/// BytesMut impl only std::fmt::Write but not std::io::Write. The
342/// latter trait is required for minidom's
343/// `Element::write_to_inner()`.
344struct WriteBytes<'a> {
345 dst: &'a mut BytesMut,
346}
347
348impl<'a> WriteBytes<'a> {
349 fn new(dst: &'a mut BytesMut) -> Self {
350 WriteBytes { dst }
351 }
352}
353
354impl<'a> std::io::Write for WriteBytes<'a> {
355 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
356 self.dst.put_slice(buf);
357 Ok(buf.len())
358 }
359
360 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
361 Ok(())
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    // The XML declaration plus an opening `<stream:stream>` decodes to a
    // `StreamStart` packet.
    #[test]
    fn test_stream_start() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });
    }

    // After the stream header, `</stream:stream>` decodes to `StreamEnd`.
    #[test]
    fn test_stream_end() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });
        b.clear();
        b.put_slice(b"</stream:stream>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamEnd)) => true,
            _ => false,
        });
    }

    // A stanza whose closing tag is split across two `decode` calls is only
    // emitted once the final `>` arrives.
    #[test]
    fn test_truncated_stanza() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        b.clear();
        b.put_slice("<test>ß</test".as_bytes());
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(None) => true,
            _ => false,
        });

        b.clear();
        b.put_slice(b">");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
            _ => false,
        });
    }

    // A two-byte UTF-8 character ("ß" = 0xC3 0x9F) split across two
    // `decode` calls is reassembled rather than rejected.
    #[test]
    fn test_truncated_utf8() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        b.clear();
        b.put(&b"<test>\xc3"[..]);
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(None) => true,
            _ => false,
        });

        b.clear();
        b.put(&b"\x9f</test>"[..]);
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
            _ => false,
        });
    }

    /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
    // A prefixed attribute such as `xml:lang` keeps its `prefix:` in the
    // attribute name of the decoded element.
    #[test]
    fn test_atrribute_prefix() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        b.clear();
        b.put_slice(b"<status xml:lang='en'>Test status</status>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el)))
                if el.name() == "status"
                    && el.text() == "Test status"
                    && el.attr("xml:lang").map_or(false, |a| a == "en") =>
                true,
            _ => false,
        });
    }

    /// By default, encode() only gets a BytesMut that has 8kb space reserved.
    /// This checks that a stanza larger than that is still written out in full.
    #[test]
    fn test_large_stanza() {
        use futures::{executor::block_on, sink::SinkExt};
        use std::io::Cursor;
        use tokio_util::codec::FramedWrite;
        let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
        let mut text = "".to_owned();
        for _ in 0..2usize.pow(15) {
            text = text + "A";
        }
        let stanza = Element::builder("message")
            .append(Element::builder("body").append(text.as_ref()).build())
            .build();
        block_on(framed.send(Packet::Stanza(stanza))).expect("send");
        assert_eq!(
            framed.get_ref().get_ref(),
            &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
        );
    }

    // A stanza written into the buffer with several `put_slice` calls before
    // a single `decode` still comes out as one complete stanza.
    #[test]
    fn test_cut_out_stanza() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        b.clear();
        b.put_slice(b"<message ");
        b.put_slice(b"type='chat'><body>Foo</body></message>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(_))) => true,
            _ => false,
        });
    }
}