1use std;
2use std::default::Default;
3use std::iter::FromIterator;
4use std::cell::RefCell;
5use std::rc::Rc;
6use std::fmt::Write;
7use std::str::from_utf8;
8use std::io::{Error, ErrorKind};
9use std::collections::HashMap;
10use std::collections::vec_deque::VecDeque;
11use tokio_io::codec::{Encoder, Decoder};
12use minidom::{Element, Node};
13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
14use xml5ever::interface::Attribute;
15use bytes::{BytesMut, BufMut};
16
17// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
18
/// One unit of XMPP stream traffic: what the decoder yields and what the
/// encoder accepts.
#[derive(Debug)]
pub enum Packet {
    /// A problem reported while tokenizing the incoming XML.
    Error(Box<std::error::Error>),
    /// The opening tag of the outermost element, with its attributes keyed
    /// by local name.
    StreamStart(HashMap<String, String>),
    /// A complete element that is a direct child of the outermost element.
    Stanza(Element),
    /// Character data that is not nested inside a stanza.
    Text(String),
    /// The outermost element was closed, or the tokenizer signalled EOF.
    StreamEnd,
}
27
/// Receives tokens from the XML tokenizer and assembles them into `Packet`s.
struct ParserSink {
    // Ready stanzas, shared with XMPPCodec
    queue: Rc<RefCell<VecDeque<Packet>>>,
    // Parsing stack: the elements that are currently open, outermost first
    stack: Vec<Element>,
    // One map per open element holding the namespaces declared on its start
    // tag; the key is the prefix, with `None` standing for plain `xmlns`.
    ns_stack: Vec<HashMap<Option<String>, String>>,
}
35
36impl ParserSink {
37 pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
38 ParserSink {
39 queue,
40 stack: vec![],
41 ns_stack: vec![],
42 }
43 }
44
45 fn push_queue(&self, pkt: Packet) {
46 self.queue.borrow_mut().push_back(pkt);
47 }
48
49 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
50 for nss in self.ns_stack.iter().rev() {
51 match nss.get(prefix) {
52 Some(ns) => return Some(ns),
53 None => (),
54 }
55 }
56
57 None
58 }
59
60 fn handle_start_tag(&mut self, tag: Tag) {
61 let mut nss = HashMap::new();
62 let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
63 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
64 .unwrap_or(false);
65 for attr in &tag.attrs {
66 match attr.name.local.as_ref() {
67 "xmlns" => {
68 nss.insert(None, attr.value.as_ref().to_owned());
69 },
70 prefix if is_prefix_xmlns(attr) => {
71 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
72 },
73 _ => (),
74 }
75 }
76 self.ns_stack.push(nss);
77
78 let el = {
79 let mut el_builder = Element::builder(tag.name.local.as_ref());
80 match self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned())) {
81 Some(el_ns) => el_builder = el_builder.ns(el_ns),
82 None => (),
83 }
84 for attr in &tag.attrs {
85 match attr.name.local.as_ref() {
86 "xmlns" => (),
87 _ if is_prefix_xmlns(attr) => (),
88 _ => {
89 el_builder = el_builder.attr(
90 attr.name.local.as_ref(),
91 attr.value.as_ref()
92 );
93 },
94 }
95 }
96 el_builder.build()
97 };
98
99 if self.stack.is_empty() {
100 let attrs = HashMap::from_iter(
101 tag.attrs.iter()
102 .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
103 );
104 self.push_queue(Packet::StreamStart(attrs));
105 }
106
107 self.stack.push(el);
108 }
109
110 fn handle_end_tag(&mut self) {
111 let el = self.stack.pop().unwrap();
112 self.ns_stack.pop();
113
114 match self.stack.len() {
115 // </stream:stream>
116 0 =>
117 self.push_queue(Packet::StreamEnd),
118 // </stanza>
119 1 =>
120 self.push_queue(Packet::Stanza(el)),
121 len => {
122 let parent = &mut self.stack[len - 1];
123 parent.append_child(el);
124 },
125 }
126 }
127}
128
129impl TokenSink for ParserSink {
130 fn process_token(&mut self, token: Token) {
131 match token {
132 Token::TagToken(tag) => match tag.kind {
133 TagKind::StartTag =>
134 self.handle_start_tag(tag),
135 TagKind::EndTag =>
136 self.handle_end_tag(),
137 TagKind::EmptyTag => {
138 self.handle_start_tag(tag);
139 self.handle_end_tag();
140 },
141 TagKind::ShortTag =>
142 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
143 },
144 Token::CharacterTokens(tendril) =>
145 match self.stack.len() {
146 0 | 1 =>
147 self.push_queue(Packet::Text(tendril.into())),
148 len => {
149 let el = &mut self.stack[len - 1];
150 el.append_text_node(tendril);
151 },
152 },
153 Token::EOFToken =>
154 self.push_queue(Packet::StreamEnd),
155 Token::ParseError(s) => {
156 println!("ParseError: {:?}", s);
157 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
158 },
159 _ => (),
160 }
161 }
162
163 // fn end(&mut self) {
164 // }
165}
166
/// Codec that turns incoming bytes into `Packet`s and writes outgoing
/// `Packet`s as XML.
pub struct XMPPCodec {
    /// Outgoing: value of the `xmlns` attribute of the last stream start that
    /// was encoded; passed as the parent namespace when writing stanzas.
    ns: Option<String>,
    /// Incoming: tokenizer that feeds the `ParserSink`.
    parser: XmlTokenizer<ParserSink>,
    /// For handling incoming truncated utf8: holds the bytes of an incomplete
    /// trailing sequence until the next chunk arrives.
    // TODO: optimize using tendrils?
    buf: Vec<u8>,
    /// Shared with ParserSink, which pushes the packets that `decode` pops.
    queue: Rc<RefCell<VecDeque<Packet>>>,
}
178
179impl XMPPCodec {
180 pub fn new() -> Self {
181 let queue = Rc::new(RefCell::new((VecDeque::new())));
182 let sink = ParserSink::new(queue.clone());
183 // TODO: configure parser?
184 let parser = XmlTokenizer::new(sink, Default::default());
185 XMPPCodec {
186 ns: None,
187 parser,
188 queue,
189 buf: vec![],
190 }
191 }
192}
193
194impl Decoder for XMPPCodec {
195 type Item = Packet;
196 type Error = Error;
197
198 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
199 let buf1: Box<AsRef<[u8]>> =
200 if self.buf.len() > 0 && buf.len() > 0 {
201 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
202 prefix.extend_from_slice(buf.take().as_ref());
203 Box::new(prefix)
204 } else {
205 Box::new(buf.take())
206 };
207 let buf1 = buf1.as_ref().as_ref();
208 match from_utf8(buf1) {
209 Ok(s) => {
210 if s.len() > 0 {
211 println!("<< {}", s);
212 let tendril = FromIterator::from_iter(s.chars());
213 self.parser.feed(tendril);
214 }
215 },
216 // Remedies for truncated utf8
217 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
218 // Prepare all the valid data
219 let mut b = BytesMut::with_capacity(e.valid_up_to());
220 b.put(&buf1[0..e.valid_up_to()]);
221
222 // Retry
223 let result = self.decode(&mut b);
224
225 // Keep the tail back in
226 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
227
228 return result;
229 },
230 Err(e) => {
231 println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
232 return Err(Error::new(ErrorKind::InvalidInput, e));
233 },
234 }
235
236 let result = self.queue.borrow_mut().pop_front();
237 Ok(result)
238 }
239
240 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
241 self.decode(buf)
242 }
243}
244
245impl Encoder for XMPPCodec {
246 type Item = Packet;
247 type Error = Error;
248
249 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
250 let remaining = dst.capacity() - dst.len();
251 let max_stanza_size: usize = 2usize.pow(16);
252 if remaining < max_stanza_size {
253 dst.reserve(max_stanza_size - remaining);
254 }
255
256 match item {
257 Packet::StreamStart(start_attrs) => {
258 let mut buf = String::new();
259 write!(buf, "<stream:stream").unwrap();
260 for (name, value) in start_attrs.into_iter() {
261 write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
262 .unwrap();
263 if name == "xmlns" {
264 self.ns = Some(value);
265 }
266 }
267 write!(buf, ">\n").unwrap();
268
269 print!(">> {}", buf);
270 write!(dst, "{}", buf)
271 .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
272 },
273 Packet::Stanza(stanza) => {
274 let root_ns = self.ns.as_ref().map(|s| s.as_ref());
275 write_element(&stanza, dst, root_ns)
276 .and_then(|_| {
277 println!(">> {:?}", dst);
278 Ok(())
279 })
280 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
281 },
282 Packet::Text(text) => {
283 write_text(&text, dst)
284 .and_then(|_| {
285 println!(">> {:?}", dst);
286 Ok(())
287 })
288 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
289 },
290 // TODO: Implement all
291 _ => Ok(())
292 }
293 }
294}
295
/// Writes `text` to `writer` exactly as given, without any escaping.
pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
    writer.write_str(text)
}
299
300// TODO: escape everything?
301pub fn write_element<W: Write>(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> {
302 write!(writer, "<")?;
303 write!(writer, "{}", el.name())?;
304
305 if let Some(ref ns) = el.ns() {
306 if parent_ns.map(|s| s.as_ref()) != el.ns() {
307 write!(writer, " xmlns=\"{}\"", ns)?;
308 }
309 }
310
311 for (key, value) in el.attrs() {
312 write!(writer, " {}=\"{}\"", key, value)?;
313 }
314
315 if ! el.nodes().any(|_| true) {
316 write!(writer, " />")?;
317 return Ok(())
318 }
319
320 write!(writer, ">")?;
321
322 for node in el.nodes() {
323 match node {
324 &Node::Element(ref child) =>
325 write_element(child, writer, el.ns())?,
326 &Node::Text(ref text) =>
327 write_text(text, writer)?,
328 }
329 }
330
331 write!(writer, "</{}>", el.name())?;
332 Ok(())
333}
334
/// Copied from RustyXML for now
///
/// Returns a copy of `input` in which the five characters with a predefined
/// XML entity (ampersand, less-than, greater-than, apostrophe and double
/// quote) are replaced by the matching entity reference (`amp`, `lt`, `gt`,
/// `apos`, `quot`). All other characters are copied unchanged.
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        let entity = match c {
            '&' => "amp",
            '<' => "lt",
            '>' => "gt",
            '\'' => "apos",
            '"' => "quot",
            o => {
                result.push(o);
                continue;
            }
        };
        // An entity reference is an ampersand, the entity name and a semicolon
        result.push('&');
        result.push_str(entity);
        result.push(';');
    }
    result
}
351
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// The stream's opening tag alone must already yield a `StreamStart`.
    #[test]
    fn test_stream_start() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });
    }

    /// A stanza whose closing `>` arrives in a later chunk must not be
    /// delivered until that chunk has been decoded.
    #[test]
    fn test_truncated_stanza() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        // Everything but the final `>` of the end tag
        b.clear();
        b.put(r"<test>ß</test");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(None) => true,
            _ => false,
        });

        // The missing `>` completes the stanza
        b.clear();
        b.put(r">");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el)))
                if el.name() == "test"
                && el.text() == "ß"
                => true,
            _ => false,
        });
    }

    /// The two bytes of `ß` (0xC3 0x9F) are split across two chunks; the
    /// decoder has to hold the first byte back and reassemble the character.
    #[test]
    fn test_truncated_utf8() {
        let mut c = XMPPCodec::new();
        let mut b = BytesMut::with_capacity(1024);
        b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::StreamStart(_))) => true,
            _ => false,
        });

        // Chunk ends in the first byte of the two-byte sequence
        b.clear();
        b.put(&b"<test>\xc3"[..]);
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(None) => true,
            _ => false,
        });

        // Chunk starts with the second byte of the sequence
        b.clear();
        b.put(&b"\x9f</test>"[..]);
        let r = c.decode(&mut b);
        assert!(match r {
            Ok(Some(Packet::Stanza(ref el)))
                if el.name() == "test"
                && el.text() == "ß"
                => true,
            _ => false,
        });
    }

    /// By default, encode() only gets a BytesMut that has 8kb space reserved.
    /// This sends a stanza with a 32 KiB body through `FramedWrite` and checks
    /// that it arrives complete.
    #[test]
    fn test_large_stanza() {
        use std::io::Cursor;
        use futures::{Future, Sink};
        use tokio_io::codec::FramedWrite;
        let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
        // 2^15 times the letter "A"
        let mut text = "".to_owned();
        for _ in 0..2usize.pow(15) {
            text = text + "A";
        }
        let stanza = Element::builder("message")
            .append(
                Element::builder("body")
                    .append(&text)
                    .build()
            )
            .build();
        let framed = framed.send(Packet::Stanza(stanza))
            .wait()
            .expect("send");
        assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
    }
}