1use std;
2use std::default::Default;
3use std::iter::FromIterator;
4use std::cell::RefCell;
5use std::rc::Rc;
6use std::fmt::Write;
7use std::str::from_utf8;
8use std::io::{Error, ErrorKind};
9use std::collections::HashMap;
10use std::collections::vec_deque::VecDeque;
11use tokio_io::codec::{Encoder, Decoder};
12use minidom::{Element, Node};
13use xml5ever::tokenizer::{XmlTokenizer, TokenSink, Token, Tag, TagKind};
14use xml5ever::interface::Attribute;
15use bytes::*;
16
17// const NS_XMLNS: &'static str = "http://www.w3.org/2000/xmlns/";
18
19#[derive(Debug)]
20pub enum Packet {
21 Error(Box<std::error::Error>),
22 StreamStart(HashMap<String, String>),
23 Stanza(Element),
24 Text(String),
25 StreamEnd,
26}
27
28struct ParserSink {
29 // Ready stanzas, shared with XMPPCodec
30 queue: Rc<RefCell<VecDeque<Packet>>>,
31 // Parsing stack
32 stack: Vec<Element>,
33 ns_stack: Vec<HashMap<Option<String>, String>>,
34}
35
36impl ParserSink {
37 pub fn new(queue: Rc<RefCell<VecDeque<Packet>>>) -> Self {
38 ParserSink {
39 queue,
40 stack: vec![],
41 ns_stack: vec![],
42 }
43 }
44
45 fn push_queue(&self, pkt: Packet) {
46 self.queue.borrow_mut().push_back(pkt);
47 }
48
49 fn lookup_ns(&self, prefix: &Option<String>) -> Option<&str> {
50 for nss in self.ns_stack.iter().rev() {
51 match nss.get(prefix) {
52 Some(ns) => return Some(ns),
53 None => (),
54 }
55 }
56
57 None
58 }
59
60 fn handle_start_tag(&mut self, tag: Tag) {
61 let mut nss = HashMap::new();
62 let is_prefix_xmlns = |attr: &Attribute| attr.name.prefix.as_ref()
63 .map(|prefix| prefix.eq_str_ignore_ascii_case("xmlns"))
64 .unwrap_or(false);
65 for attr in &tag.attrs {
66 match attr.name.local.as_ref() {
67 "xmlns" => {
68 nss.insert(None, attr.value.as_ref().to_owned());
69 },
70 prefix if is_prefix_xmlns(attr) => {
71 nss.insert(Some(prefix.to_owned()), attr.value.as_ref().to_owned());
72 },
73 _ => (),
74 }
75 }
76 self.ns_stack.push(nss);
77
78 let el = {
79 let mut el_builder = Element::builder(tag.name.local.as_ref());
80 match self.lookup_ns(&tag.name.prefix.map(|prefix| prefix.as_ref().to_owned())) {
81 Some(el_ns) => el_builder = el_builder.ns(el_ns),
82 None => (),
83 }
84 for attr in &tag.attrs {
85 match attr.name.local.as_ref() {
86 "xmlns" => (),
87 _ if is_prefix_xmlns(attr) => (),
88 _ => {
89 el_builder = el_builder.attr(
90 attr.name.local.as_ref(),
91 attr.value.as_ref()
92 );
93 },
94 }
95 }
96 el_builder.build()
97 };
98
99 if self.stack.is_empty() {
100 let attrs = HashMap::from_iter(
101 el.attrs()
102 .map(|(name, value)| (name.to_owned(), value.to_owned()))
103 );
104 self.push_queue(Packet::StreamStart(attrs));
105 }
106
107 self.stack.push(el);
108 }
109
110 fn handle_end_tag(&mut self) {
111 let el = self.stack.pop().unwrap();
112 self.ns_stack.pop();
113
114 match self.stack.len() {
115 // </stream:stream>
116 0 =>
117 self.push_queue(Packet::StreamEnd),
118 // </stanza>
119 1 =>
120 self.push_queue(Packet::Stanza(el)),
121 len => {
122 let parent = &mut self.stack[len - 1];
123 parent.append_child(el);
124 },
125 }
126 }
127}
128
129impl TokenSink for ParserSink {
130 fn process_token(&mut self, token: Token) {
131 match token {
132 Token::TagToken(tag) => match tag.kind {
133 TagKind::StartTag =>
134 self.handle_start_tag(tag),
135 TagKind::EndTag =>
136 self.handle_end_tag(),
137 TagKind::EmptyTag => {
138 self.handle_start_tag(tag);
139 self.handle_end_tag();
140 },
141 TagKind::ShortTag =>
142 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, "ShortTag")))),
143 },
144 Token::CharacterTokens(tendril) =>
145 match self.stack.len() {
146 0 | 1 =>
147 self.push_queue(Packet::Text(tendril.into())),
148 len => {
149 let el = &mut self.stack[len - 1];
150 el.append_text_node(tendril);
151 },
152 },
153 Token::EOFToken =>
154 self.push_queue(Packet::StreamEnd),
155 Token::ParseError(s) => {
156 println!("ParseError: {:?}", s);
157 self.push_queue(Packet::Error(Box::new(Error::new(ErrorKind::InvalidInput, (*s).to_owned()))))
158 },
159 _ => (),
160 }
161 }
162
163 // fn end(&mut self) {
164 // }
165}
166
167pub struct XMPPCodec {
168 /// Outgoing
169 ns: Option<String>,
170 /// Incoming
171 parser: XmlTokenizer<ParserSink>,
172 /// For handling incoming truncated utf8
173 // TODO: optimize using tendrils?
174 buf: Vec<u8>,
175 /// Shared with ParserSink
176 queue: Rc<RefCell<VecDeque<Packet>>>,
177}
178
179impl XMPPCodec {
180 pub fn new() -> Self {
181 let queue = Rc::new(RefCell::new((VecDeque::new())));
182 let sink = ParserSink::new(queue.clone());
183 // TODO: configure parser?
184 let parser = XmlTokenizer::new(sink, Default::default());
185 XMPPCodec {
186 ns: None,
187 parser,
188 queue,
189 buf: vec![],
190 }
191 }
192}
193
194impl Decoder for XMPPCodec {
195 type Item = Packet;
196 type Error = Error;
197
198 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
199 let buf1: Box<AsRef<[u8]>> =
200 if self.buf.len() > 0 && buf.len() > 0 {
201 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
202 prefix.extend_from_slice(buf.take().as_ref());
203 Box::new(prefix)
204 } else {
205 Box::new(buf.take())
206 };
207 let buf1 = buf1.as_ref().as_ref();
208 match from_utf8(buf1) {
209 Ok(s) => {
210 if s.len() > 0 {
211 println!("<< {}", s);
212 let tendril = FromIterator::from_iter(s.chars());
213 self.parser.feed(tendril);
214 }
215 },
216 // Remedies for truncated utf8
217 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
218 // Prepare all the valid data
219 let mut b = BytesMut::with_capacity(e.valid_up_to());
220 b.put(&buf1[0..e.valid_up_to()]);
221
222 // Retry
223 let result = self.decode(&mut b);
224
225 // Keep the tail back in
226 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
227
228 return result;
229 },
230 Err(e) => {
231 println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
232 return Err(Error::new(ErrorKind::InvalidInput, e));
233 },
234 }
235
236 let result = self.queue.borrow_mut().pop_front();
237 Ok(result)
238 }
239
240 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
241 self.decode(buf)
242 }
243}
244
245impl Encoder for XMPPCodec {
246 type Item = Packet;
247 type Error = Error;
248
249 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
250 match item {
251 Packet::StreamStart(start_attrs) => {
252 let mut buf = String::new();
253 write!(buf, "<stream:stream").unwrap();
254 for (name, value) in start_attrs.into_iter() {
255 write!(buf, " {}=\"{}\"", escape(&name), escape(&value))
256 .unwrap();
257 if name == "xmlns" {
258 self.ns = Some(value);
259 }
260 }
261 write!(buf, ">\n").unwrap();
262
263 print!(">> {}", buf);
264 write!(dst, "{}", buf)
265 .map_err(|e| Error::new(ErrorKind::InvalidInput, e))
266 },
267 Packet::Stanza(stanza) => {
268 let root_ns = self.ns.as_ref().map(|s| s.as_ref());
269 write_element(&stanza, dst, root_ns)
270 .and_then(|_| {
271 println!(">> {:?}", dst);
272 Ok(())
273 })
274 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
275 },
276 Packet::Text(text) => {
277 write_text(&text, dst)
278 .and_then(|_| {
279 println!(">> {:?}", dst);
280 Ok(())
281 })
282 .map_err(|e| Error::new(ErrorKind::InvalidInput, format!("{}", e)))
283 },
284 // TODO: Implement all
285 _ => Ok(())
286 }
287 }
288}
289
/// Writes `text` as XML character data, escaping `&`, `<`, `>`, `'` and `"`
/// with the same entities as `escape`.
///
/// Unescaped text containing `<` or `&` would otherwise produce malformed
/// XML, or let the text inject markup into the stream. Runs of ordinary
/// characters are written directly without building an intermediate `String`.
///
/// # Errors
/// Propagates any `std::fmt::Error` returned by `writer`.
pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
    let mut start = 0;
    for (idx, ch) in text.char_indices() {
        let entity = match ch {
            '&' => "&amp;",
            '<' => "&lt;",
            '>' => "&gt;",
            '\'' => "&apos;",
            '"' => "&quot;",
            _ => continue,
        };
        writer.write_str(&text[start..idx])?;
        writer.write_str(entity)?;
        start = idx + ch.len_utf8();
    }
    writer.write_str(&text[start..])
}
293
294// TODO: escape everything?
295pub fn write_element<W: Write>(el: &Element, writer: &mut W, parent_ns: Option<&str>) -> Result<(), std::fmt::Error> {
296 write!(writer, "<")?;
297 write!(writer, "{}", el.name())?;
298
299 if let Some(ref ns) = el.ns() {
300 if parent_ns.map(|s| s.as_ref()) != el.ns() {
301 write!(writer, " xmlns=\"{}\"", ns)?;
302 }
303 }
304
305 for (key, value) in el.attrs() {
306 write!(writer, " {}=\"{}\"", key, value)?;
307 }
308
309 if ! el.nodes().any(|_| true) {
310 write!(writer, " />")?;
311 return Ok(())
312 }
313
314 write!(writer, ">")?;
315
316 for node in el.nodes() {
317 match node {
318 &Node::Element(ref child) =>
319 write_element(child, writer, el.ns())?,
320 &Node::Text(ref text) =>
321 write_text(text, writer)?,
322 }
323 }
324
325 write!(writer, "</{}>", el.name())?;
326 Ok(())
327}
328
/// Escapes the five XML special characters in `input`:
/// `&` → `&amp;`, `<` → `&lt;`, `>` → `&gt;`, `'` → `&apos;`, `"` → `&quot;`.
///
/// The result is safe to embed both in character data and inside a quoted
/// attribute value. All other characters, including non-ASCII ones, are
/// copied unchanged. Adapted from RustyXML.
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
345
346#[cfg(test)]
347mod tests {
348 use super::*;
349 use bytes::BytesMut;
350
351 #[test]
352 fn test_stream_start() {
353 let mut c = XMPPCodec::new();
354 let mut b = BytesMut::with_capacity(1024);
355 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
356 let r = c.decode(&mut b);
357 assert!(match r {
358 Ok(Some(Packet::StreamStart(_))) => true,
359 _ => false,
360 });
361 }
362
363 #[test]
364 fn test_truncated_stanza() {
365 let mut c = XMPPCodec::new();
366 let mut b = BytesMut::with_capacity(1024);
367 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
368 let r = c.decode(&mut b);
369 assert!(match r {
370 Ok(Some(Packet::StreamStart(_))) => true,
371 _ => false,
372 });
373
374 b.clear();
375 b.put(r"<test>ß</test");
376 let r = c.decode(&mut b);
377 assert!(match r {
378 Ok(None) => true,
379 _ => false,
380 });
381
382 b.clear();
383 b.put(r">");
384 let r = c.decode(&mut b);
385 assert!(match r {
386 Ok(Some(Packet::Stanza(ref el)))
387 if el.name() == "test"
388 && el.text() == "ß"
389 => true,
390 _ => false,
391 });
392 }
393
394 #[test]
395 fn test_truncated_utf8() {
396 let mut c = XMPPCodec::new();
397 let mut b = BytesMut::with_capacity(1024);
398 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
399 let r = c.decode(&mut b);
400 assert!(match r {
401 Ok(Some(Packet::StreamStart(_))) => true,
402 _ => false,
403 });
404
405 b.clear();
406 b.put(&b"<test>\xc3"[..]);
407 let r = c.decode(&mut b);
408 assert!(match r {
409 Ok(None) => true,
410 _ => false,
411 });
412
413 b.clear();
414 b.put(&b"\x9f</test>"[..]);
415 let r = c.decode(&mut b);
416 assert!(match r {
417 Ok(Some(Packet::Stanza(ref el)))
418 if el.name() == "test"
419 && el.text() == "ß"
420 => true,
421 _ => false,
422 });
423 }
424}