1use std;
2use std::fmt::Write;
3use std::str::from_utf8;
4use std::io::{Error, ErrorKind};
5use std::collections::HashMap;
6use tokio_io::codec::{Encoder, Decoder};
7use xml;
8use bytes::*;
9
/// Namespace URI of the reserved `xmlns` prefix; attributes in this namespace
/// are namespace-prefix declarations such as `xmlns:stream="..."`.
const NS_XMLNS: &str = "http://www.w3.org/2000/xmlns/";

/// XML attributes keyed by `(attribute name, optional namespace URI)`, mapped to their values.
pub type Attributes = HashMap<(String, Option<String>), String>;
13
14struct XMPPRoot {
15 builder: xml::ElementBuilder,
16 pub attributes: Attributes,
17}
18
19impl XMPPRoot {
20 fn new(root: xml::StartTag) -> Self {
21 let mut builder = xml::ElementBuilder::new();
22 let mut attributes = HashMap::new();
23 for (name_ns, value) in root.attributes {
24 match name_ns {
25 (ref name, None) if name == "xmlns" =>
26 builder.set_default_ns(value),
27 (ref prefix, Some(ref ns)) if ns == NS_XMLNS =>
28 builder.define_prefix(prefix.to_owned(), value),
29 _ => {
30 attributes.insert(name_ns, value);
31 },
32 }
33 }
34
35 XMPPRoot {
36 builder: builder,
37 attributes: attributes,
38 }
39 }
40
41 fn handle_event(&mut self, event: Result<xml::Event, xml::ParserError>)
42 -> Option<Result<xml::Element, xml::BuilderError>> {
43 self.builder.handle_event(event)
44 }
45}
46
47#[derive(Debug)]
48pub enum Packet {
49 Error(Box<std::error::Error>),
50 StreamStart(HashMap<String, String>),
51 Stanza(xml::Element),
52 Text(String),
53 StreamEnd,
54}
55
56pub struct XMPPCodec {
57 parser: xml::Parser,
58 root: Option<XMPPRoot>,
59 buf: Vec<u8>,
60}
61
62impl XMPPCodec {
63 pub fn new() -> Self {
64 XMPPCodec {
65 parser: xml::Parser::new(),
66 root: None,
67 buf: vec![],
68 }
69 }
70}
71
72impl Decoder for XMPPCodec {
73 type Item = Packet;
74 type Error = Error;
75
76 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
77 let buf1: Box<AsRef<[u8]>> =
78 if self.buf.len() > 0 && buf.len() > 0 {
79 let mut prefix = std::mem::replace(&mut self.buf, vec![]);
80 prefix.extend_from_slice(buf.take().as_ref());
81 Box::new(prefix)
82 } else {
83 Box::new(buf.take())
84 };
85 let buf1 = buf1.as_ref().as_ref();
86 match from_utf8(buf1) {
87 Ok(s) => {
88 if s.len() > 0 {
89 println!("<< {}", s);
90 self.parser.feed_str(s);
91 }
92 },
93 // Remedies for truncated utf8
94 Err(e) if e.valid_up_to() >= buf1.len() - 3 => {
95 // Prepare all the valid data
96 let mut b = BytesMut::with_capacity(e.valid_up_to());
97 b.put(&buf1[0..e.valid_up_to()]);
98
99 // Retry
100 let result = self.decode(&mut b);
101
102 // Keep the tail back in
103 self.buf.extend_from_slice(&buf1[e.valid_up_to()..]);
104
105 return result;
106 },
107 Err(e) => {
108 println!("error {} at {}/{} in {:?}", e, e.valid_up_to(), buf1.len(), buf1);
109 return Err(Error::new(ErrorKind::InvalidInput, e));
110 },
111 }
112
113 let mut new_root: Option<XMPPRoot> = None;
114 let mut result = None;
115 for event in &mut self.parser {
116 match self.root {
117 None => {
118 // Expecting <stream:stream>
119 match event {
120 Ok(xml::Event::ElementStart(start_tag)) => {
121 let mut attrs: HashMap<String, String> = HashMap::new();
122 for (&(ref name, _), value) in &start_tag.attributes {
123 attrs.insert(name.to_owned(), value.to_owned());
124 }
125 result = Some(Packet::StreamStart(attrs));
126 self.root = Some(XMPPRoot::new(start_tag));
127 break
128 },
129 Err(e) => {
130 result = Some(Packet::Error(Box::new(e)));
131 break
132 },
133 _ =>
134 (),
135 }
136 }
137
138 Some(ref mut root) => {
139 match root.handle_event(event) {
140 None => (),
141 Some(Ok(stanza)) => {
142 // Emit the stanza
143 result = Some(Packet::Stanza(stanza));
144 break
145 },
146 Some(Err(e)) => {
147 result = Some(Packet::Error(Box::new(e)));
148 break
149 }
150 };
151 },
152 }
153
154 match new_root.take() {
155 None => (),
156 Some(root) => self.root = Some(root),
157 }
158 }
159
160 Ok(result)
161 }
162
163 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Error> {
164 self.decode(buf)
165 }
166}
167
168impl Encoder for XMPPCodec {
169 type Item = Packet;
170 type Error = Error;
171
172 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
173 match item {
174 Packet::StreamStart(start_attrs) => {
175 let mut buf = String::new();
176 write!(buf, "<stream:stream").unwrap();
177 for (ref name, ref value) in &start_attrs {
178 write!(buf, " {}=\"{}\"", xml::escape(&name), xml::escape(&value))
179 .unwrap();
180 }
181 write!(buf, ">\n").unwrap();
182
183 print!(">> {}", buf);
184 write!(dst, "{}", buf)
185 },
186 Packet::Stanza(stanza) => {
187 println!(">> {}", stanza);
188 write!(dst, "{}", stanza)
189 },
190 Packet::Text(text) => {
191 let escaped = xml::escape(&text);
192 println!(">> {}", escaped);
193 write!(dst, "{}", escaped)
194 },
195 // TODO: Implement all
196 _ => Ok(())
197 }
198 .map_err(|_| Error::from(ErrorKind::InvalidInput))
199 }
200}
201
202#[cfg(test)]
203mod tests {
204 use super::*;
205 use bytes::BytesMut;
206
207 #[test]
208 fn test_stream_start() {
209 let mut c = XMPPCodec::new();
210 let mut b = BytesMut::with_capacity(1024);
211 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
212 let r = c.decode(&mut b);
213 assert!(match r {
214 Ok(Some(Packet::StreamStart(_))) => true,
215 _ => false,
216 });
217 }
218
219 #[test]
220 fn test_truncated_stanza() {
221 let mut c = XMPPCodec::new();
222 let mut b = BytesMut::with_capacity(1024);
223 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
224 let r = c.decode(&mut b);
225 assert!(match r {
226 Ok(Some(Packet::StreamStart(_))) => true,
227 _ => false,
228 });
229
230 b.clear();
231 b.put(r"<test>ß</test");
232 let r = c.decode(&mut b);
233 assert!(match r {
234 Ok(None) => true,
235 _ => false,
236 });
237
238 b.clear();
239 b.put(r">");
240 let r = c.decode(&mut b);
241 assert!(match r {
242 Ok(Some(Packet::Stanza(ref el)))
243 if el.name == "test"
244 && el.content_str() == "ß"
245 => true,
246 _ => false,
247 });
248 }
249
250 #[test]
251 fn test_truncated_utf8() {
252 let mut c = XMPPCodec::new();
253 let mut b = BytesMut::with_capacity(1024);
254 b.put(r"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
255 let r = c.decode(&mut b);
256 assert!(match r {
257 Ok(Some(Packet::StreamStart(_))) => true,
258 _ => false,
259 });
260
261 b.clear();
262 b.put(&b"<test>\xc3"[..]);
263 let r = c.decode(&mut b);
264 assert!(match r {
265 Ok(None) => true,
266 _ => false,
267 });
268
269 b.clear();
270 b.put(&b"\x9f</test>"[..]);
271 let r = c.decode(&mut b);
272 assert!(match r {
273 Ok(Some(Packet::Stanza(ref el)))
274 if el.name == "test"
275 && el.content_str() == "ß"
276 => true,
277 _ => false,
278 });
279 }
280}