1//! XML stream parser for XMPP
2
3use crate::{ParseError, ParserError};
4use bytes::{BufMut, BytesMut};
5use log::{debug, error};
6use std;
7use std::borrow::Cow;
8use std::collections::vec_deque::VecDeque;
9use std::collections::HashMap;
10use std::default::Default;
11use std::fmt::Write;
12use std::io;
13use std::iter::FromIterator;
14use std::str::from_utf8;
15use std::sync::Arc;
16use std::sync::Mutex;
17use tokio_util::codec::{Decoder, Encoder};
18use xml5ever::buffer_queue::BufferQueue;
19use xml5ever::interface::Attribute;
20use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
21use xmpp_parsers::Element;
22
/// Anything that can be sent or received on an XMPP/XML stream
///
/// This is both the item type produced by the decoder and the item
/// type accepted by the encoder of `XMPPCodec`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Packet {
    /// `<stream:stream>` start tag, carrying the tag's attributes as
    /// name/value pairs. When decoding, the keys are the attributes'
    /// local names (any prefix is dropped).
    StreamStart(HashMap<String, String>),
    /// A complete stanza or nonza, i.e. a fully closed direct child
    /// element of the stream root
    Stanza(Element),
    /// Plain text (think whitespace keep-alive); when decoding this is
    /// character data found outside of any stanza
    Text(String),
    /// `</stream:stream>` closing tag
    StreamEnd,
}
35
// One entry of the queue shared between `ParserSink` (producer) and
// `XMPPCodec` (consumer): either a parsed packet or a parser error.
type QueueItem = Result<Packet, ParserError>;
37
38/// Parser state
39struct ParserSink {
40 // Ready stanzas, shared with XMPPCodec
41 queue: Arc<Mutex<VecDeque<QueueItem>>>,
42 // Parsing stack
43 stack: Vec<Element>,
44 ns_stack: Vec<HashMap<Option<String>, String>>,
45}
46
47impl ParserSink {
48 pub fn new(queue: Arc<Mutex<VecDeque<QueueItem>>>) -> Self {
49 ParserSink {
50 queue,
51 stack: vec![],
52 ns_stack: vec![],
53 }
54 }
55
56 fn push_queue(&self, pkt: Packet) {
57 self.queue.lock().unwrap().push_back(Ok(pkt));
58 }
59
60 fn push_queue_error(&self, e: ParserError) {
61 self.queue.lock().unwrap().push_back(Err(e));
62 }
63
64 /// Lookup XML namespace declaration for given prefix (or no prefix)
65 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
66 for nss in self.ns_stack.iter().rev() {
67 if let Some(ns) = nss.get(prefix) {
68 return Some(ns);
69 }
70 }
71
72 None
73 }
74
75 fn handle_start_tag(&mut self, tag: Tag) {
76 let mut nss = HashMap::new();
77 let is_prefix_xmlns = |attr: &Attribute| {
78 attr.name
79 .prefix
80 .as_ref()
81 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
82 .unwrap_or(false)
83 };
84 for attr in &tag.attrs {
85 match attr.name.local.as_ref() {
86 "xmlns" => {
87 nss.insert(None, attr.value.as_ref().to_owned());
88 }
89 prefix if is_prefix_xmlns(attr) => {
90 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
91 }
92 _ => (),
93 }
94 }
95 self.ns_stack.push(nss);
96
97 let el = {
98 let el_ns = self
99 .lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
100 .unwrap();
101 let mut el_builder = Element::builder(tag.name.local.as_ref(), el_ns);
102 for attr in &tag.attrs {
103 match attr.name.local.as_ref() {
104 "xmlns" => (),
105 _ if is_prefix_xmlns(attr) => (),
106 _ => {
107 let attr_name = if let Some(ref prefix) = attr.name.prefix {
108 Cow::Owned(format!("{}:{}", prefix, attr.name.local))
109 } else {
110 Cow::Borrowed(attr.name.local.as_ref())
111 };
112 el_builder = el_builder.attr(attr_name, attr.value.as_ref());
113 }
114 }
115 }
116 el_builder.build()
117 };
118
119 if self.stack.is_empty() {
120 let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
121 (
122 attr.name.local.as_ref().to_owned(),
123 attr.value.as_ref().to_owned(),
124 )
125 }));
126 self.push_queue(Packet::StreamStart(attrs));
127 }
128
129 self.stack.push(el);
130 }
131
132 fn handle_end_tag(&mut self) {
133 let el = self.stack.pop().unwrap();
134 self.ns_stack.pop();
135
136 match self.stack.len() {
137 // </stream:stream>
138 0 => self.push_queue(Packet::StreamEnd),
139 // </stanza>
140 1 => self.push_queue(Packet::Stanza(el)),
141 len => {
142 let parent = &mut self.stack[len - 1];
143 parent.append_child(el);
144 }
145 }
146 }
147}
148
149impl TokenSink for ParserSink {
150 fn process_token(&mut self, token: Token) {
151 match token {
152 Token::TagToken(tag) => match tag.kind {
153 TagKind::StartTag => self.handle_start_tag(tag),
154 TagKind::EndTag => self.handle_end_tag(),
155 TagKind::EmptyTag => {
156 self.handle_start_tag(tag);
157 self.handle_end_tag();
158 }
159 TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
160 },
161 Token::CharacterTokens(tendril) => match self.stack.len() {
162 0 | 1 => self.push_queue(Packet::Text(tendril.into())),
163 len => {
164 let el = &mut self.stack[len - 1];
165 el.append_text_node(tendril);
166 }
167 },
168 Token::EOFToken => self.push_queue(Packet::StreamEnd),
169 Token::ParseError(s) => {
170 self.push_queue_error(ParserError::Parse(ParseError(s)));
171 }
172 _ => (),
173 }
174 }
175
176 // fn end(&mut self) {
177 // }
178}
179
180/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
181pub struct XMPPCodec {
182 /// Outgoing
183 ns: Option<String>,
184 /// Incoming
185 parser: XmlTokenizer<ParserSink>,
186 /// For handling incoming truncated utf8
187 // TODO: optimize using tendrils?
188 buf: Vec<u8>,
189 /// Shared with ParserSink
190 queue: Arc<Mutex<VecDeque<QueueItem>>>,
191}
192
193impl XMPPCodec {
194 /// Constructor
195 pub fn new() -> Self {
196 let queue = Arc::new(Mutex::new(VecDeque::new()));
197 let sink = ParserSink::new(queue.clone());
198 // TODO: configure parser?
199 let parser = XmlTokenizer::new(sink, Default::default());
200 XMPPCodec {
201 ns: None,
202 parser,
203 queue,
204 buf: vec![],
205 }
206 }
207}
208
209impl Default for XMPPCodec {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215impl Decoder for XMPPCodec {
216 type Item = Packet;
217 type Error = ParserError;
218
219 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
220 let buf1: Box<dyn AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
221 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
222 prefix.extend_from_slice(&buf.split_to(buf.len()));
223 Box::new(prefix)
224 } else {
225 Box::new(buf.split_to(buf.len()))
226 };
227 let buf1 = buf1.as_ref().as_ref();
228 match from_utf8(buf1) {
229 Ok(s) => {
230 debug!("<< {:?}", s);
231 if !s.is_empty() {
232 let mut buffer_queue = BufferQueue::new();
233 let tendril = FromIterator::from_iter(s.chars());
234 buffer_queue.push_back(tendril);
235 self.parser.feed(&mut buffer_queue);
236 }
237 }
238 // Remedies for truncated utf8
239 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
240 // Prepare all the valid data
241 let mut b = BytesMut::with_capacity(e.valid_up_to());
242 b.put(&buf1[0..e.valid_up_to()]);
243
244 // Retry
245 let result = self.decode(&mut b);
246
247 // Keep the tail back in
248 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
249
250 return result;
251 }
252 Err(e) => {
253 error!(
254 "error {} at {}/{} in {:?}",
255 e,
256 e.valid_up_to(),
257 buf1.len(),
258 buf1
259 );
260 return Err(ParserError::Utf8(e));
261 }
262 }
263
264 match self.queue.lock().unwrap().pop_front() {
265 None => Ok(None),
266 Some(result) => result.map(|pkt| Some(pkt)),
267 }
268 }
269
270 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
271 self.decode(buf)
272 }
273}
274
275impl Encoder<Packet> for XMPPCodec {
276 type Error = io::Error;
277
278 fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
279 let remaining = dst.capacity() - dst.len();
280 let max_stanza_size: usize = 2usize.pow(16);
281 if remaining < max_stanza_size {
282 dst.reserve(max_stanza_size - remaining);
283 }
284
285 fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
286 io::Error::new(io::ErrorKind::InvalidInput, e)
287 }
288
289 match item {
290 Packet::StreamStart(start_attrs) => {
291 let mut buf = String::new();
292 write!(buf, "<stream:stream").map_err(to_io_err)?;
293 for (name, value) in start_attrs {
294 write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
295 if name == "xmlns" {
296 self.ns = Some(value);
297 }
298 }
299 write!(buf, ">\n").map_err(to_io_err)?;
300
301 debug!(">> {:?}", buf);
302 write!(dst, "{}", buf).map_err(to_io_err)
303 }
304 Packet::Stanza(stanza) => stanza
305 .write_to(&mut WriteBytes::new(dst))
306 .and_then(|_| {
307 debug!(">> {:?}", dst);
308 Ok(())
309 })
310 .map_err(|e| to_io_err(format!("{}", e))),
311 Packet::Text(text) => write_text(&text, dst)
312 .and_then(|_| {
313 debug!(">> {:?}", dst);
314 Ok(())
315 })
316 .map_err(to_io_err),
317 Packet::StreamEnd => write!(dst, "</stream:stream>\n").map_err(to_io_err),
318 }
319 }
320}
321
322/// Write XML-escaped text string
323pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
324 write!(writer, "{}", escape(text))
325}
326
/// Copied from `RustyXML` for now
///
/// Returns a copy of `input` in which the five characters with a
/// predefined XML entity (`&`, `<`, `>`, `'` and `"`) are replaced by
/// that entity, so the result can be embedded in XML text or in a
/// quoted attribute value. All other characters are copied unchanged.
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            other => result.push(other),
        }
    }
    result
}
343
344/// BytesMut impl only std::fmt::Write but not std::io::Write. The
345/// latter trait is required for minidom's
346/// `Element::write_to_inner()`.
347struct WriteBytes<'a> {
348 dst: &'a mut BytesMut,
349}
350
351impl<'a> WriteBytes<'a> {
352 fn new(dst: &'a mut BytesMut) -> Self {
353 WriteBytes { dst }
354 }
355}
356
357impl<'a> std::io::Write for WriteBytes<'a> {
358 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
359 self.dst.put_slice(buf);
360 Ok(buf.len())
361 }
362
363 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
364 Ok(())
365 }
366}
367
368#[cfg(test)]
369mod tests {
370 use super::*;
371 use bytes::BytesMut;
372
373 #[test]
374 fn test_stream_start() {
375 let mut c = XMPPCodec::new();
376 let mut b = BytesMut::with_capacity(1024);
377 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
378 let r = c.decode(&mut b);
379 assert!(match r {
380 Ok(Some(Packet::StreamStart(_))) => true,
381 _ => false,
382 });
383 }
384
385 #[test]
386 fn test_stream_end() {
387 let mut c = XMPPCodec::new();
388 let mut b = BytesMut::with_capacity(1024);
389 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
390 let r = c.decode(&mut b);
391 assert!(match r {
392 Ok(Some(Packet::StreamStart(_))) => true,
393 _ => false,
394 });
395 b.clear();
396 b.put_slice(b"</stream:stream>");
397 let r = c.decode(&mut b);
398 assert!(match r {
399 Ok(Some(Packet::StreamEnd)) => true,
400 _ => false,
401 });
402 }
403
404 #[test]
405 fn test_truncated_stanza() {
406 let mut c = XMPPCodec::new();
407 let mut b = BytesMut::with_capacity(1024);
408 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
409 let r = c.decode(&mut b);
410 assert!(match r {
411 Ok(Some(Packet::StreamStart(_))) => true,
412 _ => false,
413 });
414
415 b.clear();
416 b.put_slice("<test>ß</test".as_bytes());
417 let r = c.decode(&mut b);
418 assert!(match r {
419 Ok(None) => true,
420 _ => false,
421 });
422
423 b.clear();
424 b.put_slice(b">");
425 let r = c.decode(&mut b);
426 assert!(match r {
427 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
428 _ => false,
429 });
430 }
431
432 #[test]
433 fn test_truncated_utf8() {
434 let mut c = XMPPCodec::new();
435 let mut b = BytesMut::with_capacity(1024);
436 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
437 let r = c.decode(&mut b);
438 assert!(match r {
439 Ok(Some(Packet::StreamStart(_))) => true,
440 _ => false,
441 });
442
443 b.clear();
444 b.put(&b"<test>\xc3"[..]);
445 let r = c.decode(&mut b);
446 assert!(match r {
447 Ok(None) => true,
448 _ => false,
449 });
450
451 b.clear();
452 b.put(&b"\x9f</test>"[..]);
453 let r = c.decode(&mut b);
454 assert!(match r {
455 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
456 _ => false,
457 });
458 }
459
460 /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
461 #[test]
462 fn test_atrribute_prefix() {
463 let mut c = XMPPCodec::new();
464 let mut b = BytesMut::with_capacity(1024);
465 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
466 let r = c.decode(&mut b);
467 assert!(match r {
468 Ok(Some(Packet::StreamStart(_))) => true,
469 _ => false,
470 });
471
472 b.clear();
473 b.put_slice(b"<status xml:lang='en'>Test status</status>");
474 let r = c.decode(&mut b);
475 assert!(match r {
476 Ok(Some(Packet::Stanza(ref el)))
477 if el.name() == "status"
478 && el.text() == "Test status"
479 && el.attr("xml:lang").map_or(false, |a| a == "en") =>
480 true,
481 _ => false,
482 });
483 }
484
485 /// By default, encode() only get's a BytesMut that has 8kb space reserved.
486 #[test]
487 fn test_large_stanza() {
488 use futures::{executor::block_on, sink::SinkExt};
489 use std::io::Cursor;
490 use tokio_util::codec::FramedWrite;
491 let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
492 let mut text = "".to_owned();
493 for _ in 0..2usize.pow(15) {
494 text = text + "A";
495 }
496 let stanza = Element::builder("message", "jabber:client")
497 .append(
498 Element::builder("body", "jabber:client")
499 .append(text.as_ref())
500 .build(),
501 )
502 .build();
503 block_on(framed.send(Packet::Stanza(stanza))).expect("send");
504 assert_eq!(
505 framed.get_ref().get_ref(),
506 &("<message xmlns=\"jabber:client\"><body>".to_owned() + &text + "</body></message>")
507 .as_bytes()
508 );
509 }
510
511 #[test]
512 fn test_cut_out_stanza() {
513 let mut c = XMPPCodec::new();
514 let mut b = BytesMut::with_capacity(1024);
515 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
516 let r = c.decode(&mut b);
517 assert!(match r {
518 Ok(Some(Packet::StreamStart(_))) => true,
519 _ => false,
520 });
521
522 b.clear();
523 b.put_slice(b"<message ");
524 b.put_slice(b"type='chat'><body>Foo</body></message>");
525 let r = c.decode(&mut b);
526 assert!(match r {
527 Ok(Some(Packet::Stanza(_))) => true,
528 _ => false,
529 });
530 }
531}