1use std;
2use std::default::Default;
3use std::iter::FromIterator;
4use std::cell::RefCell;
5use std::rc::Rc;
6use std::fmt::Write;
7use std::str::from_utf8;
8use std::io::{Error, ErrorKind};
9use std::collections::HashMap;
10use std::collections::vec_deque::VecDeque;
11use tokio_io::codec::{Encoder, Decoder};
12use minidom::Element;
13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
14use xml5ever::interface::Attribute;
15use bytes::{BytesMut, BufMut};
16
17// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
18
19#[derive(Debug)]
20pub enum Packet {
21 Error(Box<std::error::Error>),
22 StreamStart(HashMap<String, String>),
23 Stanza(Element),
24 Text(String),
25 StreamEnd,
26}
27
28struct ParserSink {
29 // Ready stanzas, shared with XMPPCodec
30 queue: Rc<RefCell<VecDeque<Packet>>>,
31 // Parsing stack
32 stack: Vec<Element>,
33 ns_stack: Vec<HashMap<Option<String>, String>>,
34}
35
36impl ParserSink {
37 pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
38 ParserSink {
39 queue,
40 stack: vec![],
41 ns_stack: vec![],
42 }
43 }
44
45 fn push_queue(&self, pkt: Packet) {
46 self.queue.borrow_mut().push_back(pkt);
47 }
48
49 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
50 for nss in self.ns_stack.iter().rev() {
51 if let Some(ns) = nss.get(prefix) {
52 return Some(ns);
53 }
54 }
55
56 None
57 }
58
59 fn handle_start_tag(&mut self, tag: Tag) {
60 let mut nss = HashMap::new();
61 let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
62 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
63 .unwrap_or(false);
64 for attr in &tag.attrs {
65 match attr.name.local.as_ref() {
66 "xmlns" => {
67 nss.insert(None, attr.value.as_ref().to_owned());
68 },
69 prefix if is_prefix_xmlns(attr) => {
70 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
71 },
72 _ => (),
73 }
74 }
75 self.ns_stack.push(nss);
76
77 let el = {
78 let mut el_builder = Element::builder(tag.name.local.as_ref());
79 if let Some(el_ns) = self.lookup_ns(
80 &tag.name.prefix.map(|prefix|
81 prefix.as_ref().to_owned())
82 ) {
83 el_builder = el_builder.ns(el_ns);
84 }
85 for attr in &tag.attrs {
86 match attr.name.local.as_ref() {
87 "xmlns" => (),
88 _ if is_prefix_xmlns(attr) => (),
89 _ => {
90 el_builder = el_builder.attr(
91 attr.name.local.as_ref(),
92 attr.value.as_ref()
93 );
94 },
95 }
96 }
97 el_builder.build()
98 };
99
100 if self.stack.is_empty() {
101 let attrs = HashMap::from_iter(
102 tag.attrs.iter()
103 .map(|attr| (attr.name.local.as_ref().to_owned(), attr.value.as_ref().to_owned()))
104 );
105 self.push_queue(Packet::StreamStart(attrs));
106 }
107
108 self.stack.push(el);
109 }
110
111 fn handle_end_tag(&mut self) {
112 let el = self.stack.pop().unwrap();
113 self.ns_stack.pop();
114
115 match self.stack.len() {
116 // </stream:stream>
117 0 =>
118 self.push_queue(Packet::StreamEnd),
119 // </stanza>
120 1 =>
121 self.push_queue(Packet::Stanza(el)),
122 len => {
123 let parent = &mut self.stack[len - 1];
124 parent.append_child(el);
125 },
126 }
127 }
128}
129
130impl TokenSink for ParserSink {
131 fn process_token(&mut self, token: Token) {
132 match token {
133 Token::TagToken(tag) => match tag.kind {
134 TagKind::StartTag =>
135 self.handle_start_tag(tag),
136 TagKind::EndTag =>
137 self.handle_end_tag(),
138 TagKind::EmptyTag => {
139 self.handle_start_tag(tag);
140 self.handle_end_tag();
141 },
142 TagKind::ShortTag =>
143 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
144 },
145 Token::CharacterTokens(tendril) =>
146 match self.stack.len() {
147 0 | 1 =>
148 self.push_queue(Packet::Text(tendril.into())),
149 len => {
150 let el = &mut self.stack[len - 1];
151 el.append_text_node(tendril);
152 },
153 },
154 Token::EOFToken =>
155 self.push_queue(Packet::StreamEnd),
156 Token::ParseError(s) => {
157 println!("ParseError: {:?}", s);
158 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
159 },
160 _ => (),
161 }
162 }
163
164 // fn end(&mut self) {
165 // }
166}
167
168pub struct XMPPCodec {
169 /// Outgoing
170 ns: Option<String>,
171 /// Incoming
172 parser: XmlTokenizer<ParserSink>,
173 /// For handling incoming truncated utf8
174 // TODO: optimize using tendrils?
175 buf: Vec<u8>,
176 /// Shared with ParserSink
177 queue: Rc<RefCell<VecDeque<Packet>>>,
178}
179
180impl XMPPCodec {
181 pub fn new() -> Self {
182 let queue = Rc::new(RefCell::new(VecDeque::new()));
183 let sink = ParserSink::new(queue.clone());
184 // TODO: configure parser?
185 let parser = XmlTokenizer::new(sink, Default::default());
186 XMPPCodec {
187 ns: None,
188 parser,
189 queue,
190 buf: vec![],
191 }
192 }
193}
194
195impl Default for XMPPCodec {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201impl Decoder for XMPPCodec {
202 type Item = Packet;
203 type Error = Error;
204
205 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
206 let buf1: Box<AsRef<[u8]>> =
207 if ! self.buf.is_empty() && ! buf.is_empty() {
208 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
209 prefix.extend_from_slice(buf.take().as_ref());
210 Box::new(prefix)
211 } else {
212 Box::new(buf.take())
213 };
214 let buf1 = buf1.as_ref().as_ref();
215 match from_utf8(buf1) {
216 Ok(s) => {
217 if ! s.is_empty() {
218 println!("<< {}", s);
219 let tendril = FromIterator::from_iter(s.chars());
220 self.parser.feed(tendril);
221 }
222 },
223 // Remedies for truncated utf8
224 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
225 // Prepare all the valid data
226 let mut b = BytesMut::with_capacity(e.valid_up_to());
227 b.put(&buf1[0..e.valid_up_to()]);
228
229 // Retry
230 let result = self.decode(&mut b);
231
232 // Keep the tail back in
233 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
234
235 return result;
236 },
237 Err(e) => {
238 println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
239 return Err(Error::new(ErrorKind::InvalidInput, e));
240 },
241 }
242
243 let result = self.queue.borrow_mut().pop_front();
244 Ok(result)
245 }
246
247 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
248 self.decode(buf)
249 }
250}
251
252impl Encoder for XMPPCodec {
253 type Item = Packet;
254 type Error = Error;
255
256 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
257 let remaining = dst.capacity() - dst.len();
258 let max_stanza_size: usize = 2usize.pow(16);
259 if remaining < max_stanza_size {
260 dst.reserve(max_stanza_size - remaining);
261 }
262
263 match item {
264 Packet::StreamStart(start_attrs) => {
265 let mut buf = String::new();
266 write!(buf, "<stream:stream").unwrap();
267 for (name, value) in start_attrs {
268 write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
269 .unwrap();
270 if name == "xmlns" {
271 self.ns = Some(value);
272 }
273 }
274 write!(buf, ">\n").unwrap();
275
276 print!(">> {}", buf);
277 write!(dst, "{}", buf)
278 .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
279 },
280 Packet::Stanza(stanza) => {
281 stanza.write_to_inner(&mut WriteBytes::new(dst))
282 .and_then(|_| {
283 println!(">> {:?}", dst);
284 Ok(())
285 })
286 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
287 },
288 Packet::Text(text) => {
289 write_text(&text, dst)
290 .and_then(|_| {
291 println!(">> {:?}", dst);
292 Ok(())
293 })
294 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
295 },
296 // TODO: Implement all
297 _ => Ok(())
298 }
299 }
300}
301
302pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
303 write!(writer, "{}", escape(text))
304}
305
/// Replaces the five characters that have predefined XML entities
/// (`&`, `<`, `>`, `'` and `"`) with their entity references so that
/// `input` can be embedded in character data or in an attribute value.
/// All other characters are copied unchanged.
///
/// Copied from `RustyXML` for now
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
322
323/// BytesMut impl only std::fmt::Write but not std::io::Write. The
324/// latter trait is required for minidom's
325/// `Element::write_to_inner()`.
326struct WriteBytes<'a> {
327 dst: &'a mut BytesMut,
328}
329
330impl<'a> WriteBytes<'a> {
331 fn new(dst: &'a mut BytesMut) -> Self {
332 WriteBytes { dst }
333 }
334}
335
336impl<'a> std::io::Write for WriteBytes<'a> {
337 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
338 self.dst.put_slice(buf);
339 Ok(buf.len())
340 }
341
342 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
343 Ok(())
344 }
345}
346
347
348#[cfg(test)]
349mod tests {
350 use super::*;
351 use bytes::BytesMut;
352
353 #[test]
354 fn test_stream_start() {
355 let mut c = XMPPCodec::new();
356 let mut b = BytesMut::with_capacity(1024);
357 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
358 let r = c.decode(&mut b);
359 assert!(match r {
360 Ok(Some(Packet::StreamStart(_))) => true,
361 _ => false,
362 });
363 }
364
365 #[test]
366 fn test_truncated_stanza() {
367 let mut c = XMPPCodec::new();
368 let mut b = BytesMut::with_capacity(1024);
369 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
370 let r = c.decode(&mut b);
371 assert!(match r {
372 Ok(Some(Packet::StreamStart(_))) => true,
373 _ => false,
374 });
375
376 b.clear();
377 b.put(r"<test>ß</test");
378 let r = c.decode(&mut b);
379 assert!(match r {
380 Ok(None) => true,
381 _ => false,
382 });
383
384 b.clear();
385 b.put(r">");
386 let r = c.decode(&mut b);
387 assert!(match r {
388 Ok(Some(Packet::Stanza(ref el)))
389 if el.name() == "test"
390 && el.text() == "ß"
391 => true,
392 _ => false,
393 });
394 }
395
396 #[test]
397 fn test_truncated_utf8() {
398 let mut c = XMPPCodec::new();
399 let mut b = BytesMut::with_capacity(1024);
400 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
401 let r = c.decode(&mut b);
402 assert!(match r {
403 Ok(Some(Packet::StreamStart(_))) => true,
404 _ => false,
405 });
406
407 b.clear();
408 b.put(&b"<test>\xc3"[..]);
409 let r = c.decode(&mut b);
410 assert!(match r {
411 Ok(None) => true,
412 _ => false,
413 });
414
415 b.clear();
416 b.put(&b"\x9f</test>"[..]);
417 let r = c.decode(&mut b);
418 assert!(match r {
419 Ok(Some(Packet::Stanza(ref el)))
420 if el.name() == "test"
421 && el.text() == "ß"
422 => true,
423 _ => false,
424 });
425 }
426
427 /// By default, encode() only get's a BytesMut that has 8kb space reserved.
428 #[test]
429 fn test_large_stanza() {
430 use std::io::Cursor;
431 use futures::{Future, Sink};
432 use tokio_io::codec::FramedWrite;
433 let framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
434 let mut text = "".to_owned();
435 for _ in 0..2usize.pow(15) {
436 text = text + "A";
437 }
438 let stanza = Element::builder("message")
439 .append(
440 Element::builder("body")
441 .append(&text)
442 .build()
443 )
444 .build();
445 let framed = framed.send(Packet::Stanza(stanza))
446 .wait()
447 .expect("send");
448 assert_eq!(framed.get_ref().get_ref(), &("<message><body>".to_owned() + &text + "</body></message>").as_bytes());
449 }
450}