1//! XML stream parser for XMPP
2
3use crate::{ParseError, ParserError};
4use bytes::{BufMut, BytesMut};
5use minidom::Element;
6use quick_xml::Writer as EventWriter;
7use std;
8use std::cell::RefCell;
9use std::collections::vec_deque::VecDeque;
10use std::collections::HashMap;
11use std::default::Default;
12use std::fmt::Write;
13use std::io;
14use std::iter::FromIterator;
15use std::rc::Rc;
16use std::str::from_utf8;
17use tokio_codec::{Decoder, Encoder};
18use xml5ever::interface::Attribute;
19use xml5ever::tokenizer::{Tag, TagKind, Token, TokenSink, XmlTokenizer};
20
/// Anything that can be sent or received on an XMPP/XML stream.
///
/// This is the item type of both the `Decoder` and the `Encoder`
/// implementation of `XMPPCodec`.
#[derive(Debug)]
pub enum Packet {
    /// `<stream:stream>` start tag with its attributes. The parser keys
    /// them by local attribute name (any prefix is dropped).
    StreamStart(HashMap<String, String>),
    /// A complete stanza or nonza
    Stanza(Element),
    /// Plain text (think whitespace keep-alive)
    Text(String),
    /// `</stream:stream>` closing tag
    StreamEnd,
}
33
/// One entry of the queue shared between `ParserSink` and `XMPPCodec`:
/// either a parsed packet or the error the parser reported.
type QueueItem = Result<Packet, ParserError>;
35
/// Parser state
///
/// Receives tokens from the XML tokenizer, assembles them into elements
/// and pushes finished packets (or errors) onto `queue`.
struct ParserSink {
    // Ready stanzas, shared with XMPPCodec
    queue: Rc<RefCell<VecDeque<QueueItem>>>,
    // Parsing stack: the currently open elements, outermost first
    stack: Vec<Element>,
    // Namespace declarations made by each open element, pushed and popped
    // together with `stack`; keyed by prefix (`None` = default namespace)
    ns_stack: Vec<HashMap<Option<String>, String>>,
}
44
45impl ParserSink {
46 pub fn new(queue: Rc<RefCell<VecDeque<QueueItem>>>) -> Self {
47 ParserSink {
48 queue,
49 stack: vec![],
50 ns_stack: vec![],
51 }
52 }
53
54 fn push_queue(&self, pkt: Packet) {
55 self.queue.borrow_mut().push_back(Ok(pkt));
56 }
57
58 fn push_queue_error(&self, e: ParserError) {
59 self.queue.borrow_mut().push_back(Err(e));
60 }
61
62 /// Lookup XML namespace declaration for given prefix (or no prefix)
63 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
64 for nss in self.ns_stack.iter().rev() {
65 if let Some(ns) = nss.get(prefix) {
66 return Some(ns);
67 }
68 }
69
70 None
71 }
72
73 fn handle_start_tag(&mut self, tag: Tag) {
74 let mut nss = HashMap::new();
75 let is_prefix_xmlns = |attr: &Attribute| {
76 attr.name
77 .prefix
78 .as_ref()
79 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
80 .unwrap_or(false)
81 };
82 for attr in &tag.attrs {
83 match attr.name.local.as_ref() {
84 "xmlns" => {
85 nss.insert(None, attr.value.as_ref().to_owned());
86 }
87 prefix if is_prefix_xmlns(attr) => {
88 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
89 }
90 _ => (),
91 }
92 }
93 self.ns_stack.push(nss);
94
95 let el = {
96 let mut el_builder = Element::builder(tag.name.local.as_ref());
97 if let Some(el_ns) =
98 self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned()))
99 {
100 el_builder = el_builder.ns(el_ns);
101 }
102 for attr in &tag.attrs {
103 match attr.name.local.as_ref() {
104 "xmlns" => (),
105 _ if is_prefix_xmlns(attr) => (),
106 _ => {
107 el_builder = el_builder.attr(attr.name.local.as_ref(), attr.value.as_ref());
108 }
109 }
110 }
111 el_builder.build()
112 };
113
114 if self.stack.is_empty() {
115 let attrs = HashMap::from_iter(tag.attrs.iter().map(|attr| {
116 (
117 attr.name.local.as_ref().to_owned(),
118 attr.value.as_ref().to_owned(),
119 )
120 }));
121 self.push_queue(Packet::StreamStart(attrs));
122 }
123
124 self.stack.push(el);
125 }
126
127 fn handle_end_tag(&mut self) {
128 let el = self.stack.pop().unwrap();
129 self.ns_stack.pop();
130
131 match self.stack.len() {
132 // </stream:stream>
133 0 => self.push_queue(Packet::StreamEnd),
134 // </stanza>
135 1 => self.push_queue(Packet::Stanza(el)),
136 len => {
137 let parent = &mut self.stack[len - 1];
138 parent.append_child(el);
139 }
140 }
141 }
142}
143
144impl TokenSink for ParserSink {
145 fn process_token(&mut self, token: Token) {
146 match token {
147 Token::TagToken(tag) => match tag.kind {
148 TagKind::StartTag => self.handle_start_tag(tag),
149 TagKind::EndTag => self.handle_end_tag(),
150 TagKind::EmptyTag => {
151 self.handle_start_tag(tag);
152 self.handle_end_tag();
153 }
154 TagKind::ShortTag => self.push_queue_error(ParserError::ShortTag),
155 },
156 Token::CharacterTokens(tendril) => match self.stack.len() {
157 0 | 1 => self.push_queue(Packet::Text(tendril.into())),
158 len => {
159 let el = &mut self.stack[len - 1];
160 el.append_text_node(tendril);
161 }
162 },
163 Token::EOFToken => self.push_queue(Packet::StreamEnd),
164 Token::ParseError(s) => {
165 // println!("ParseError: {:?}", s);
166 self.push_queue_error(ParserError::Parse(ParseError(s)));
167 }
168 _ => (),
169 }
170 }
171
172 // fn end(&mut self) {
173 // }
174}
175
/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
pub struct XMPPCodec {
    /// Outgoing: value of the `xmlns` attribute of the most recently
    /// encoded stream start, if any
    ns: Option<String>,
    /// Incoming: tokenizer that feeds its tokens into a `ParserSink`
    parser: XmlTokenizer<ParserSink>,
    /// For handling incoming truncated utf8: bytes held back until the
    /// next call to `decode()` supplies more input
    // TODO: optimize using tendrils?
    buf: Vec<u8>,
    /// Shared with ParserSink
    queue: Rc<RefCell<VecDeque<QueueItem>>>,
}
188
189impl XMPPCodec {
190 /// Constructor
191 pub fn new() -> Self {
192 let queue = Rc::new(RefCell::new(VecDeque::new()));
193 let sink = ParserSink::new(queue.clone());
194 // TODO: configure parser?
195 let parser = XmlTokenizer::new(sink, Default::default());
196 XMPPCodec {
197 ns: None,
198 parser,
199 queue,
200 buf: vec![],
201 }
202 }
203}
204
205impl Default for XMPPCodec {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl Decoder for XMPPCodec {
212 type Item = Packet;
213 type Error = ParserError;
214
215 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
216 let buf1: Box<AsRef<[u8]>> = if !self.buf.is_empty() && !buf.is_empty() {
217 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
218 prefix.extend_from_slice(buf.take().as_ref());
219 Box::new(prefix)
220 } else {
221 Box::new(buf.take())
222 };
223 let buf1 = buf1.as_ref().as_ref();
224 match from_utf8(buf1) {
225 Ok(s) => {
226 if !s.is_empty() {
227 // println!("<< {}", s);
228 let tendril = FromIterator::from_iter(s.chars());
229 self.parser.feed(tendril);
230 }
231 }
232 // Remedies for truncated utf8
233 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
234 // Prepare all the valid data
235 let mut b = BytesMut::with_capacity(e.valid_up_to());
236 b.put(&buf1[0..e.valid_up_to()]);
237
238 // Retry
239 let result = self.decode(&mut b);
240
241 // Keep the tail back in
242 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
243
244 return result;
245 }
246 Err(e) => {
247 // println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
248 return Err(ParserError::Utf8(e));
249 }
250 }
251
252 match self.queue.borrow_mut().pop_front() {
253 None => Ok(None),
254 Some(result) => result.map(|pkt| Some(pkt)),
255 }
256 }
257
258 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
259 self.decode(buf)
260 }
261}
262
263impl Encoder for XMPPCodec {
264 type Item = Packet;
265 type Error = io::Error;
266
267 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
268 let remaining = dst.capacity() - dst.len();
269 let max_stanza_size: usize = 2usize.pow(16);
270 if remaining < max_stanza_size {
271 dst.reserve(max_stanza_size - remaining);
272 }
273
274 match item {
275 Packet::StreamStart(start_attrs) => {
276 let mut buf = String::new();
277 write!(buf, "<stream:stream").unwrap();
278 for (name, value) in start_attrs {
279 write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).unwrap();
280 if name == "xmlns" {
281 self.ns = Some(value);
282 }
283 }
284 write!(buf, ">\n").unwrap();
285
286 print!(">> {}", buf);
287 write!(dst, "{}", buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
288 }
289 Packet::Stanza(stanza) => {
290 stanza
291 .write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
292 .and_then(|_| {
293 // println!(">> {:?}", dst);
294 Ok(())
295 })
296 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
297 }
298 Packet::Text(text) => {
299 write_text(&text, dst)
300 .and_then(|_| {
301 // println!(">> {:?}", dst);
302 Ok(())
303 })
304 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{}", e)))
305 }
306 // TODO: Implement all
307 _ => Ok(()),
308 }
309 }
310}
311
312/// Write XML-escaped text string
313pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
314 write!(writer, "{}", escape(text))
315}
316
/// Replaces the five XML special characters with their predefined entities.
///
/// `&`, `<`, `>`, `'` and `"` become `&amp;`, `&lt;`, `&gt;`, `&apos;` and
/// `&quot;`; every other character is copied unchanged. The result is safe
/// to place in XML text content and in quoted attribute values.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
333
/// `BytesMut` implements only `std::fmt::Write` but not `std::io::Write`.
/// The latter trait is required for minidom's
/// `Element::write_to_inner()`, so this adapter provides it.
struct WriteBytes<'a> {
    // Buffer that all written bytes are appended to
    dst: &'a mut BytesMut,
}
340
341impl<'a> WriteBytes<'a> {
342 fn new(dst: &'a mut BytesMut) -> Self {
343 WriteBytes { dst }
344 }
345}
346
347impl<'a> std::io::Write for WriteBytes<'a> {
348 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
349 self.dst.put_slice(buf);
350 Ok(buf.len())
351 }
352
353 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
354 Ok(())
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// The stream header alone must decode to a `StreamStart` packet.
    #[test]
    fn test_stream_start() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });
    }

    /// A stanza whose closing tag is split across two reads is only
    /// delivered once the closing tag is complete.
    #[test]
    fn test_truncated_stanza() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        // The closing tag lacks its final `>`: nothing to deliver yet.
        b.clear();
        b.put(r"<test>ß</test");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(None) => true,
            _ => false,
        });

        // The missing `>` completes the stanza.
        b.clear();
        b.put(r">");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
            _ => false,
        });
    }

    /// A multi-byte UTF-8 character split across two reads is reassembled.
    #[test]
    fn test_truncated_utf8() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        // `\xc3` is only the first byte of the two-byte encoding of `ß`.
        b.clear();
        b.put(&b"<test>\xc3"[..]);
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(None) => true,
            _ => false,
        });

        // `\x9f` is the second byte; together they decode to `ß`.
        b.clear();
        b.put(&b"\x9f</test>"[..]);
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
            _ => false,
        });
    }

    /// By default, encode() only gets a BytesMut that has 8kb space reserved.
    #[test]
    fn test_large_stanza() {
        use futures::{Future, Sink};
        use std::io::Cursor;
        use tokio_codec::FramedWrite;
        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
        // Build a body of 2^15 bytes, larger than the 8kb mentioned above.
        let mut text = "".to_owned();
        for _ in 0..2usize.pow(15) {
            text = text + "A";
        }
        let stanza = Element::builder("message")
            .append(Element::builder("body").append(&text).build())
            .build();
        let framed = framed.send(Packet::Stanza(stanza)).wait().expect("send");
        assert_eq!(
            framed.get_ref().get_ref(),
            &("<message><body>".to_owned() + &text + "</body></message>").as_bytes()
        );
    }
}