1use std;
2use std::default::Default;
3use std::iter::FromIterator;
4use std::cell::RefCell;
5use std::rc::Rc;
6use std::fmt::Write;
7use std::str::from_utf8;
8use std::io::{Error, ErrorKind};
9use std::collections::HashMap;
10use std::collections::vec_deque::VecDeque;
11use tokio_codec::{Encoder, Decoder};
12use minidom::Element;
13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
14use xml5ever::interface::Attribute;
15use bytes::{BytesMut, BufMut};
16use quick_xml::Writer as EventWriter;
17
18// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
19
20#[derive(Debug)]
21pub enum Packet {
22 Error(Box<std::error::Error>),
23 StreamStart(HashMap<String, String>),
24 Stanza(Element),
25 Text(String),
26 StreamEnd,
27}
28
29struct ParserSink {
30 // Ready stanzas, shared with XMPPCodec
31 queue: Rc<RefCell<VecDeque<Packet>>>,
32 // Parsing stack
33 stack: Vec<Element>,
34 ns_stack: Vec<HashMap<Option<String>, String>>,
35}
36
37impl ParserSink {
38 pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
39 ParserSink {
40 queue,
41 stack: vec![],
42 ns_stack: vec![],
43 }
44 }
45
46 fn push_queue(&self, pkt: Packet) {
47 self.queue.borrow_mut().push_back(pkt);
48 }
49
50 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
51 for nss in self.ns_stack.iter().rev() {
52 if let Some(ns) = nss.get(prefix) {
53 return Some(ns);
54 }
55 }
56
57 None
58 }
59
60 fn handle_start_tag(&mut self, tag: Tag) {
61 let mut nss = HashMap::new();
62 let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
63 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
64 .unwrap_or(false);
65 for attr in &tag.attrs {
66 match attr.name.local.as_ref() {
67 "xmlns" => {
68 nss.insert(None, attr.value.as_ref().to_owned());
69 },
70 prefix if is_prefix_xmlns(attr) => {
71 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
72 },
73 _ => (),
74 }
75 }
76 self.ns_stack.push(nss);
77
78 let el = {
79 let mut el_builder = Element::builder(tag.name.local.as_ref());
80 if let Some(el_ns) = self.lookup_ns(
81 &tag.name.prefix.map(|prefix|
82 prefix.as_ref().to_owned())
83 ) {
84 el_builder = el_builder.ns(el_ns);
85 }
86 for attr in &tag.attrs {
87 match attr.name.local.as_ref() {
88 "xmlns" => (),
89 _ if is_prefix_xmlns(attr) => (),
90 _ => {
91 el_builder = el_builder.attr(
92 attr.name.local.as_ref(),
93 attr.value.as_ref()
94 );
95 },
96 }
97 }
98 el_builder.build()
99 };
100
101 if self.stack.is_empty() {
102 let attrs = HashMap::from_iter(
103 tag.attrs.iter()
104 .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
105 );
106 self.push_queue(Packet::StreamStart(attrs));
107 }
108
109 self.stack.push(el);
110 }
111
112 fn handle_end_tag(&mut self) {
113 let el = self.stack.pop().unwrap();
114 self.ns_stack.pop();
115
116 match self.stack.len() {
117 // </stream:stream>
118 0 =>
119 self.push_queue(Packet::StreamEnd),
120 // </stanza>
121 1 =>
122 self.push_queue(Packet::Stanza(el)),
123 len => {
124 let parent = &mut self.stack[len - 1];
125 parent.append_child(el);
126 },
127 }
128 }
129}
130
131impl TokenSink for ParserSink {
132 fn process_token(&mut self, token: Token) {
133 match token {
134 Token::TagToken(tag) => match tag.kind {
135 TagKind::StartTag =>
136 self.handle_start_tag(tag),
137 TagKind::EndTag =>
138 self.handle_end_tag(),
139 TagKind::EmptyTag => {
140 self.handle_start_tag(tag);
141 self.handle_end_tag();
142 },
143 TagKind::ShortTag =>
144 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
145 },
146 Token::CharacterTokens(tendril) =>
147 match self.stack.len() {
148 0 | 1 =>
149 self.push_queue(Packet::Text(tendril.into())),
150 len => {
151 let el = &mut self.stack[len - 1];
152 el.append_text_node(tendril);
153 },
154 },
155 Token::EOFToken =>
156 self.push_queue(Packet::StreamEnd),
157 Token::ParseError(s) => {
158 println!("ParseError: {:?}", s);
159 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
160 },
161 _ => (),
162 }
163 }
164
165 // fn end(&mut self) {
166 // }
167}
168
169pub struct XMPPCodec {
170 /// Outgoing
171 ns: Option<String>,
172 /// Incoming
173 parser: XmlTokenizer<ParserSink>,
174 /// For handling incoming truncated utf8
175 // TODO: optimize using tendrils?
176 buf: Vec<u8>,
177 /// Shared with ParserSink
178 queue: Rc<RefCell<VecDeque<Packet>>>,
179}
180
181impl XMPPCodec {
182 pub fn new() -> Self {
183 let queue = Rc::new(RefCell::new(VecDeque::new()));
184 let sink = ParserSink::new(queue.clone());
185 // TODO: configure parser?
186 let parser = XmlTokenizer::new(sink, Default::default());
187 XMPPCodec {
188 ns: None,
189 parser,
190 queue,
191 buf: vec![],
192 }
193 }
194}
195
196impl Default for XMPPCodec {
197 fn default() -> Self {
198 Self::new()
199 }
200}
201
202impl Decoder for XMPPCodec {
203 type Item = Packet;
204 type Error = Error;
205
206 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
207 let buf1: Box<AsRef<[u8]>> =
208 if ! self.buf.is_empty() && ! buf.is_empty() {
209 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
210 prefix.extend_from_slice(buf.take().as_ref());
211 Box::new(prefix)
212 } else {
213 Box::new(buf.take())
214 };
215 let buf1 = buf1.as_ref().as_ref();
216 match from_utf8(buf1) {
217 Ok(s) => {
218 if ! s.is_empty() {
219 println!("<< {}", s);
220 let tendril = FromIterator::from_iter(s.chars());
221 self.parser.feed(tendril);
222 }
223 },
224 // Remedies for truncated utf8
225 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
226 // Prepare all the valid data
227 let mut b = BytesMut::with_capacity(e.valid_up_to());
228 b.put(&buf1[0..e.valid_up_to()]);
229
230 // Retry
231 let result = self.decode(&mut b);
232
233 // Keep the tail back in
234 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
235
236 return result;
237 },
238 Err(e) => {
239 println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
240 return Err(Error::new(ErrorKind::InvalidInput, e));
241 },
242 }
243
244 let result = self.queue.borrow_mut().pop_front();
245 Ok(result)
246 }
247
248 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
249 self.decode(buf)
250 }
251}
252
253impl Encoder for XMPPCodec {
254 type Item = Packet;
255 type Error = Error;
256
257 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
258 let remaining = dst.capacity() - dst.len();
259 let max_stanza_size: usize = 2usize.pow(16);
260 if remaining < max_stanza_size {
261 dst.reserve(max_stanza_size - remaining);
262 }
263
264 match item {
265 Packet::StreamStart(start_attrs) => {
266 let mut buf = String::new();
267 write!(buf, "<stream:stream").unwrap();
268 for (name, value) in start_attrs {
269 write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
270 .unwrap();
271 if name == "xmlns" {
272 self.ns = Some(value);
273 }
274 }
275 write!(buf, ">\n").unwrap();
276
277 print!(">> {}", buf);
278 write!(dst, "{}", buf)
279 .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
280 },
281 Packet::Stanza(stanza) => {
282 stanza.write_to_inner(&mut EventWriter::new(WriteBytes::new(dst)))
283 .and_then(|_| {
284 println!(">> {:?}", dst);
285 Ok(())
286 })
287 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
288 },
289 Packet::Text(text) => {
290 write_text(&text, dst)
291 .and_then(|_| {
292 println!(">> {:?}", dst);
293 Ok(())
294 })
295 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
296 },
297 // TODO: Implement all
298 _ => Ok(())
299 }
300 }
301}
302
303pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
304 write!(writer, "{}", escape(text))
305}
306
/// Replaces the five XML special characters in `input` with their
/// predefined entities so the result is safe in both character data and
/// quoted attribute values.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
323
324/// BytesMut impl only std::fmt::Write but not std::io::Write. The
325/// latter trait is required for minidom's
326/// `Element::write_to_inner()`.
327struct WriteBytes<'a> {
328 dst: &'a mut BytesMut,
329}
330
331impl<'a> WriteBytes<'a> {
332 fn new(dst: &'a mut BytesMut) -> Self {
333 WriteBytes { dst }
334 }
335}
336
337impl<'a> std::io::Write for WriteBytes<'a> {
338 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
339 self.dst.put_slice(buf);
340 Ok(buf.len())
341 }
342
343 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
344 Ok(())
345 }
346}
347
348
349#[cfg(test)]
350mod tests {
351 use super::*;
352 use bytes::BytesMut;
353
354 #[test]
355 fn test_stream_start() {
356 let mut c = XMPPCodec::new();
357 let mut b = BytesMut::with_capacity(1024);
358 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
359 let r = c.decode(&mut b);
360 assert!(match r {
361 Ok(Some(Packet::StreamStart(_))) => true,
362 _ => false,
363 });
364 }
365
366 #[test]
367 fn test_truncated_stanza() {
368 let mut c = XMPPCodec::new();
369 let mut b = BytesMut::with_capacity(1024);
370 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
371 let r = c.decode(&mut b);
372 assert!(match r {
373 Ok(Some(Packet::StreamStart(_))) => true,
374 _ => false,
375 });
376
377 b.clear();
378 b.put(r"<test>ß</test");
379 let r = c.decode(&mut b);
380 assert!(match r {
381 Ok(None) => true,
382 _ => false,
383 });
384
385 b.clear();
386 b.put(r">");
387 let r = c.decode(&mut b);
388 assert!(match r {
389 Ok(Some(Packet::Stanza(ref el)))
390 if el.name() == "test"
391 && el.text() == "ß"
392 => true,
393 _ => false,
394 });
395 }
396
397 #[test]
398 fn test_truncated_utf8() {
399 let mut c = XMPPCodec::new();
400 let mut b = BytesMut::with_capacity(1024);
401 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
402 let r = c.decode(&mut b);
403 assert!(match r {
404 Ok(Some(Packet::StreamStart(_))) => true,
405 _ => false,
406 });
407
408 b.clear();
409 b.put(&b"<test>\xc3"[..]);
410 let r = c.decode(&mut b);
411 assert!(match r {
412 Ok(None) => true,
413 _ => false,
414 });
415
416 b.clear();
417 b.put(&b"\x9f</test>"[..]);
418 let r = c.decode(&mut b);
419 assert!(match r {
420 Ok(Some(Packet::Stanza(ref el)))
421 if el.name() == "test"
422 && el.text() == "ß"
423 => true,
424 _ => false,
425 });
426 }
427
428 /// By default, encode() only get's a BytesMut that has 8kb space reserved.
429 #[test]
430 fn test_large_stanza() {
431 use std::io::Cursor;
432 use futures::{Future, Sink};
433 use tokio_codec::FramedWrite;
434 let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
435 let mut text = "".to_owned();
436 for _ in 0..2usize.pow(15) {
437 text = text + "A";
438 }
439 let stanza = Element::builder("message")
440 .append(
441 Element::builder("body")
442 .append(&text)
443 .build()
444 )
445 .build();
446 let framed = framed.send(Packet::Stanza(stanza))
447 .wait()
448 .expect("send");
449 assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
450 }
451}