//! XML stream codec for XMPP: parses incoming bytes into packets and serializes outgoing packets.
2
3use crate::Error;
4use bytes::{BufMut, BytesMut};
5use log::debug;
6use minidom::tree_builder::TreeBuilder;
7use rxml::{Lexer, PushDriver, RawParser};
8use std;
9use std::collections::HashMap;
10use std::default::Default;
11use std::fmt::Write;
12use std::io;
13use tokio_util::codec::{Decoder, Encoder};
14use xmpp_parsers::Element;
15
16/// Anything that can be sent or received on an XMPP/XML stream
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum Packet {
19 /// `<stream:stream>` start tag
20 StreamStart(HashMap<String, String>),
21 /// A complete stanza or nonza
22 Stanza(Element),
23 /// Plain text (think whitespace keep-alive)
24 Text(String),
25 /// `</stream:stream>` closing tag
26 StreamEnd,
27}
28
29/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
30pub struct XMPPCodec {
31 /// Outgoing
32 ns: Option<String>,
33 /// Incoming
34 driver: PushDriver<RawParser>,
35 stanza_builder: TreeBuilder,
36}
37
38impl XMPPCodec {
39 /// Constructor
40 pub fn new() -> Self {
41 let stanza_builder = TreeBuilder::new();
42 let driver = PushDriver::wrap(Lexer::new(), RawParser::new());
43 XMPPCodec {
44 ns: None,
45 driver,
46 stanza_builder,
47 }
48 }
49}
50
51impl Default for XMPPCodec {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl Decoder for XMPPCodec {
58 type Item = Packet;
59 type Error = Error;
60
61 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
62 loop {
63 let token = match self.driver.parse(buf, false) {
64 Ok(Some(token)) => token,
65 Ok(None) => break,
66 Err(rxml::Error::IO(e)) if e.kind() == std::io::ErrorKind::WouldBlock => break,
67 Err(e) => return Err(minidom::Error::from(e).into()),
68 };
69
70 let had_stream_root = self.stanza_builder.depth() > 0;
71 self.stanza_builder.process_event(token)?;
72 let has_stream_root = self.stanza_builder.depth() > 0;
73
74 if !had_stream_root && has_stream_root {
75 let root = self.stanza_builder.top().unwrap();
76 let attrs =
77 root.attrs()
78 .map(|(name, value)| (name.to_owned(), value.to_owned()))
79 .chain(root.prefixes.declared_prefixes().iter().map(
80 |(prefix, namespace)| {
81 (
82 prefix
83 .as_ref()
84 .map(|prefix| format!("xmlns:{}", prefix))
85 .unwrap_or_else(|| "xmlns".to_owned()),
86 namespace.clone(),
87 )
88 },
89 ))
90 .collect();
91 debug!("<< {}", String::from(root));
92 return Ok(Some(Packet::StreamStart(attrs)));
93 } else if self.stanza_builder.depth() == 1 {
94 self.driver.release_temporaries();
95
96 if let Some(stanza) = self.stanza_builder.unshift_child() {
97 debug!("<< {}", String::from(&stanza));
98 return Ok(Some(Packet::Stanza(stanza)));
99 }
100 } else if let Some(_) = self.stanza_builder.root.take() {
101 self.driver.release_temporaries();
102
103 debug!("<< </stream:stream>");
104 return Ok(Some(Packet::StreamEnd));
105 }
106 }
107
108 Ok(None)
109 }
110
111 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
112 self.decode(buf)
113 }
114}
115
116impl Encoder<Packet> for XMPPCodec {
117 type Error = Error;
118
119 fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
120 let remaining = dst.capacity() - dst.len();
121 let max_stanza_size: usize = 2usize.pow(16);
122 if remaining < max_stanza_size {
123 dst.reserve(max_stanza_size - remaining);
124 }
125
126 fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
127 io::Error::new(io::ErrorKind::InvalidInput, e)
128 }
129
130 match item {
131 Packet::StreamStart(start_attrs) => {
132 let mut buf = String::new();
133 write!(buf, "<stream:stream").map_err(to_io_err)?;
134 for (name, value) in start_attrs {
135 write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
136 if name == "xmlns" {
137 self.ns = Some(value);
138 }
139 }
140 write!(buf, ">\n").map_err(to_io_err)?;
141
142 let utf8 = std::str::from_utf8(dst)?;
143 debug!(">> {}", utf8);
144 write!(dst, "{}", buf)?
145 }
146 Packet::Stanza(stanza) => {
147 let _ = stanza
148 .write_to(&mut WriteBytes::new(dst))
149 .map_err(|e| to_io_err(format!("{}", e)))?;
150 let utf8 = std::str::from_utf8(dst)?;
151 debug!(">> {}", utf8);
152 }
153 Packet::Text(text) => {
154 let _ = write_text(&text, dst).map_err(to_io_err)?;
155 let utf8 = std::str::from_utf8(dst)?;
156 debug!(">> {}", utf8);
157 }
158 Packet::StreamEnd => {
159 let _ = write!(dst, "</stream:stream>\n").map_err(to_io_err);
160 }
161 }
162
163 Ok(())
164 }
165}
166
167/// Write XML-escaped text string
168pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
169 write!(writer, "{}", escape(text))
170}
171
/// Escape the five XML special characters in `input`.
///
/// `&`, `<`, `>`, `'` and `"` are replaced by their predefined entity
/// references, so the result can be embedded in XML text content or in a
/// single- or double-quoted attribute value. All other characters are kept
/// unchanged. Originally copied from `RustyXML`.
pub fn escape(input: &str) -> String {
    // Most inputs contain few special characters, so the input length is a
    // good initial capacity.
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            other => result.push(other),
        }
    }
    result
}
188
189/// BytesMut impl only std::fmt::Write but not std::io::Write. The
190/// latter trait is required for minidom's
191/// `Element::write_to_inner()`.
192struct WriteBytes<'a> {
193 dst: &'a mut BytesMut,
194}
195
196impl<'a> WriteBytes<'a> {
197 fn new(dst: &'a mut BytesMut) -> Self {
198 WriteBytes { dst }
199 }
200}
201
202impl<'a> std::io::Write for WriteBytes<'a> {
203 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
204 self.dst.put_slice(buf);
205 Ok(buf.len())
206 }
207
208 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
209 Ok(())
210 }
211}
212
213#[cfg(test)]
214mod tests {
215 use super::*;
216 use bytes::BytesMut;
217
218 #[test]
219 fn test_stream_start() {
220 let mut c = XMPPCodec::new();
221 let mut b = BytesMut::with_capacity(1024);
222 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
223 let r = c.decode(&mut b);
224 assert!(match r {
225 Ok(Some(Packet::StreamStart(_))) => true,
226 _ => false,
227 });
228 }
229
230 #[test]
231 fn test_stream_end() {
232 let mut c = XMPPCodec::new();
233 let mut b = BytesMut::with_capacity(1024);
234 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
235 let r = c.decode(&mut b);
236 assert!(match r {
237 Ok(Some(Packet::StreamStart(_))) => true,
238 _ => false,
239 });
240 b.put_slice(b"</stream:stream>");
241 let r = c.decode(&mut b);
242 assert!(match r {
243 Ok(Some(Packet::StreamEnd)) => true,
244 _ => false,
245 });
246 }
247
248 #[test]
249 fn test_truncated_stanza() {
250 let mut c = XMPPCodec::new();
251 let mut b = BytesMut::with_capacity(1024);
252 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
253 let r = c.decode(&mut b);
254 assert!(match r {
255 Ok(Some(Packet::StreamStart(_))) => true,
256 _ => false,
257 });
258
259 b.put_slice("<test>ß</test".as_bytes());
260 let r = c.decode(&mut b);
261 assert!(match r {
262 Ok(None) => true,
263 _ => false,
264 });
265
266 b.put_slice(b">");
267 let r = c.decode(&mut b);
268 assert!(match r {
269 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
270 _ => false,
271 });
272 }
273
274 #[test]
275 fn test_truncated_utf8() {
276 let mut c = XMPPCodec::new();
277 let mut b = BytesMut::with_capacity(1024);
278 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
279 let r = c.decode(&mut b);
280 assert!(match r {
281 Ok(Some(Packet::StreamStart(_))) => true,
282 _ => false,
283 });
284
285 b.put(&b"<test>\xc3"[..]);
286 let r = c.decode(&mut b);
287 assert!(match r {
288 Ok(None) => true,
289 _ => false,
290 });
291
292 b.put(&b"\x9f</test>"[..]);
293 let r = c.decode(&mut b);
294 assert!(match r {
295 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
296 _ => false,
297 });
298 }
299
300 /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
301 #[test]
302 fn test_atrribute_prefix() {
303 let mut c = XMPPCodec::new();
304 let mut b = BytesMut::with_capacity(1024);
305 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
306 let r = c.decode(&mut b);
307 assert!(match r {
308 Ok(Some(Packet::StreamStart(_))) => true,
309 _ => false,
310 });
311
312 b.put_slice(b"<status xml:lang='en'>Test status</status>");
313 let r = c.decode(&mut b);
314 assert!(match r {
315 Ok(Some(Packet::Stanza(ref el)))
316 if el.name() == "status"
317 && el.text() == "Test status"
318 && el.attr("xml:lang").map_or(false, |a| a == "en") =>
319 true,
320 _ => false,
321 });
322 }
323
324 /// By default, encode() only get's a BytesMut that has 8kb space reserved.
325 #[test]
326 fn test_large_stanza() {
327 use futures::{executor::block_on, sink::SinkExt};
328 use std::io::Cursor;
329 use tokio_util::codec::FramedWrite;
330 let mut framed = FramedWrite::new(Cursor::new(vec![]), XMPPCodec::new());
331 let mut text = "".to_owned();
332 for _ in 0..2usize.pow(15) {
333 text = text + "A";
334 }
335 let stanza = Element::builder("message", "jabber:client")
336 .append(
337 Element::builder("body", "jabber:client")
338 .append(text.as_ref())
339 .build(),
340 )
341 .build();
342 block_on(framed.send(Packet::Stanza(stanza))).expect("send");
343 assert_eq!(
344 framed.get_ref().get_ref(),
345 &format!(
346 "<message xmlns='jabber:client'><body>{}</body></message>",
347 text
348 )
349 .as_bytes()
350 );
351 }
352
353 #[test]
354 fn test_cut_out_stanza() {
355 let mut c = XMPPCodec::new();
356 let mut b = BytesMut::with_capacity(1024);
357 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
358 let r = c.decode(&mut b);
359 assert!(match r {
360 Ok(Some(Packet::StreamStart(_))) => true,
361 _ => false,
362 });
363
364 b.put_slice(b"<message ");
365 b.put_slice(b"type='chat'><body>Foo</body></message>");
366 let r = c.decode(&mut b);
367 assert!(match r {
368 Ok(Some(Packet::Stanza(_))) => true,
369 _ => false,
370 });
371 }
372}