1//! XML stream parser for XMPP
2
3use crate::Error;
4use bytes::{BufMut, BytesMut};
5use log::debug;
6use minidom::tree_builder::TreeBuilder;
7use rxml::{Parse, RawParser};
8use std::collections::HashMap;
9use std::fmt::Write;
10use std::io;
11#[cfg(feature = "syntax-highlighting")]
12use std::sync::OnceLock;
13use tokio_util::codec::{Decoder, Encoder};
14use xmpp_parsers::Element;
15
/// Colour theme used when rendering logged XML; filled in by `init_syntect`.
#[cfg(feature = "syntax-highlighting")]
static THEME: OnceLock<syntect::highlighting::Theme> = OnceLock::new();
/// The XML syntax definition looked up from `PS`; filled in by `init_syntect`.
#[cfg(feature = "syntax-highlighting")]
static SYNTAX: OnceLock<syntect::parsing::SyntaxReference> = OnceLock::new();
/// syntect's syntax set; also consulted by `XmppCodec::new` to decide whether
/// `init_syntect` still has to run.
#[cfg(feature = "syntax-highlighting")]
static PS: OnceLock<syntect::parsing::SyntaxSet> = OnceLock::new();
22
/// Load syntect's default syntaxes and themes and publish them through the
/// `THEME`, `SYNTAX` and `PS` globals.
///
/// Safe to call more than once and from several threads: a value that has
/// already been stored by an earlier call is kept, and the freshly loaded
/// duplicate is simply dropped instead of panicking.
///
/// # Panics
///
/// Panics if syntect's bundled defaults lack the XML syntax or the
/// "Solarized (dark)" theme.
#[cfg(feature = "syntax-highlighting")]
fn init_syntect() {
    let ps = syntect::parsing::SyntaxSet::load_defaults_newlines();
    let syntax = ps
        .find_syntax_by_extension("xml")
        .expect("syntect default syntax set should contain an XML syntax")
        .clone();
    let mut ts = syntect::highlighting::ThemeSet::load_defaults();
    let theme = ts
        .themes
        .remove("Solarized (dark)")
        .expect("syntect default theme set should contain \"Solarized (dark)\"");

    // `PS` is what callers test to decide whether initialisation is needed,
    // so it is stored last: once it is visible, the other two are as well.
    // A failed `set` only means another thread got there first.
    let _ = THEME.set(theme);
    let _ = SYNTAX.set(syntax);
    let _ = PS.set(ps);
}
34
/// Render `xml` with 24-bit terminal colour escapes for debug logging.
///
/// Returns the highlighted text followed by an SGR reset (`ESC[0m`). If the
/// syntect globals have not been initialised, or syntect fails to highlight
/// the input, the XML is returned unchanged so that logging never panics.
#[cfg(feature = "syntax-highlighting")]
fn highlight_xml(xml: &str) -> String {
    let (Some(syntax), Some(theme), Some(ps)) = (SYNTAX.get(), THEME.get(), PS.get()) else {
        // Highlighting was never set up (e.g. debug logging was switched on
        // after the codec was constructed): log the plain text instead.
        return xml.to_owned();
    };

    let mut highlighter = syntect::easy::HighlightLines::new(syntax, theme);
    match highlighter.highlight_line(xml, ps) {
        Ok(ranges) => {
            let escaped = syntect::util::as_24_bit_terminal_escaped(&ranges[..], false);
            format!("{}\x1b[0m", escaped)
        }
        Err(_) => xml.to_owned(),
    }
}
42
/// Pass-through used when the `syntax-highlighting` feature is disabled:
/// the XML is handed back exactly as it was given.
#[cfg(not(feature = "syntax-highlighting"))]
fn highlight_xml(xml: &str) -> &str {
    xml
}
47
48/// Anything that can be sent or received on an XMPP/XML stream
49#[derive(Debug, Clone, PartialEq, Eq)]
50pub enum Packet {
51 /// `<stream:stream>` start tag
52 StreamStart(HashMap<String, String>),
53 /// A complete stanza or nonza
54 Stanza(Element),
55 /// Plain text (think whitespace keep-alive)
56 Text(String),
57 /// `</stream:stream>` closing tag
58 StreamEnd,
59}
60
61/// Stateful encoder/decoder for a bytestream from/to XMPP `Packet`
62pub struct XmppCodec {
63 /// Outgoing
64 ns: Option<String>,
65 /// Incoming
66 driver: RawParser,
67 stanza_builder: TreeBuilder,
68}
69
70impl XmppCodec {
71 /// Constructor
72 pub fn new() -> Self {
73 let stanza_builder = TreeBuilder::new();
74 let driver = RawParser::new();
75 #[cfg(feature = "syntax-highlighting")]
76 if log::log_enabled!(log::Level::Debug) && PS.get().is_none() {
77 init_syntect();
78 }
79 XmppCodec {
80 ns: None,
81 driver,
82 stanza_builder,
83 }
84 }
85}
86
87impl Default for XmppCodec {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl Decoder for XmppCodec {
94 type Item = Packet;
95 type Error = Error;
96
97 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
98 loop {
99 let token = match self.driver.parse_buf(buf, false) {
100 Ok(Some(token)) => token,
101 Ok(None) => break,
102 Err(rxml::Error::IO(e)) if e.kind() == std::io::ErrorKind::WouldBlock => break,
103 Err(e) => return Err(minidom::Error::from(e).into()),
104 };
105
106 let had_stream_root = self.stanza_builder.depth() > 0;
107 self.stanza_builder.process_event(token)?;
108 let has_stream_root = self.stanza_builder.depth() > 0;
109
110 if !had_stream_root && has_stream_root {
111 let root = self.stanza_builder.top().unwrap();
112 let attrs =
113 root.attrs()
114 .map(|(name, value)| (name.to_owned(), value.to_owned()))
115 .chain(root.prefixes.declared_prefixes().iter().map(
116 |(prefix, namespace)| {
117 (
118 prefix
119 .as_ref()
120 .map(|prefix| format!("xmlns:{}", prefix))
121 .unwrap_or_else(|| "xmlns".to_owned()),
122 namespace.clone(),
123 )
124 },
125 ))
126 .collect();
127 debug!("<< {}", highlight_xml(&String::from(root)));
128 return Ok(Some(Packet::StreamStart(attrs)));
129 } else if self.stanza_builder.depth() == 1 {
130 self.driver.release_temporaries();
131
132 if let Some(stanza) = self.stanza_builder.unshift_child() {
133 debug!("<< {}", highlight_xml(&String::from(&stanza)));
134 return Ok(Some(Packet::Stanza(stanza)));
135 }
136 } else if let Some(_) = self.stanza_builder.root.take() {
137 self.driver.release_temporaries();
138
139 debug!("<< {}", highlight_xml("</stream:stream>"));
140 return Ok(Some(Packet::StreamEnd));
141 }
142 }
143
144 Ok(None)
145 }
146
147 fn decode_eof(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
148 self.decode(buf)
149 }
150}
151
152impl Encoder<Packet> for XmppCodec {
153 type Error = Error;
154
155 fn encode(&mut self, item: Packet, dst: &mut BytesMut) -> Result<(), Self::Error> {
156 let remaining = dst.capacity() - dst.len();
157 let max_stanza_size: usize = 2usize.pow(16);
158 if remaining < max_stanza_size {
159 dst.reserve(max_stanza_size - remaining);
160 }
161
162 fn to_io_err<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
163 io::Error::new(io::ErrorKind::InvalidInput, e)
164 }
165
166 match item {
167 Packet::StreamStart(start_attrs) => {
168 let mut buf = String::new();
169 write!(buf, "<stream:stream").map_err(to_io_err)?;
170 for (name, value) in start_attrs {
171 write!(buf, " {}=\"{}\"", escape(&name), escape(&value)).map_err(to_io_err)?;
172 if name == "xmlns" {
173 self.ns = Some(value);
174 }
175 }
176 write!(buf, ">").map_err(to_io_err)?;
177
178 write!(dst, "{}", buf)?;
179 let utf8 = std::str::from_utf8(dst)?;
180 debug!(">> {}", highlight_xml(utf8))
181 }
182 Packet::Stanza(stanza) => {
183 let _ = stanza
184 .write_to(&mut WriteBytes::new(dst))
185 .map_err(|e| to_io_err(format!("{}", e)))?;
186 let utf8 = std::str::from_utf8(dst)?;
187 debug!(">> {}", highlight_xml(utf8));
188 }
189 Packet::Text(text) => {
190 let _ = write_text(&text, dst).map_err(to_io_err)?;
191 let utf8 = std::str::from_utf8(dst)?;
192 debug!(">> {}", highlight_xml(utf8));
193 }
194 Packet::StreamEnd => {
195 let _ = write!(dst, "</stream:stream>\n").map_err(to_io_err);
196 debug!(">> {}", highlight_xml("</stream:stream>"));
197 }
198 }
199
200 Ok(())
201 }
202}
203
204/// Write XML-escaped text string
205pub fn write_text<W: Write>(text: &str, writer: &mut W) -> Result<(), std::fmt::Error> {
206 write!(writer, "{}", escape(text))
207}
208
/// Replace the five XML special characters with their predefined entities.
///
/// `&`, `<`, `>`, `'` and `"` become `&amp;`, `&lt;`, `&gt;`, `&apos;` and
/// `&quot;`; every other character is copied through unchanged. The result is
/// safe to place in element text and in quoted attribute values.
///
/// Copied from `RustyXML` for now.
pub fn escape(input: &str) -> String {
    let mut result = String::with_capacity(input.len());

    for c in input.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '\'' => result.push_str("&apos;"),
            '"' => result.push_str("&quot;"),
            o => result.push(o),
        }
    }
    result
}
225
226/// BytesMut impl only std::fmt::Write but not std::io::Write. The
227/// latter trait is required for minidom's
228/// `Element::write_to_inner()`.
229struct WriteBytes<'a> {
230 dst: &'a mut BytesMut,
231}
232
233impl<'a> WriteBytes<'a> {
234 fn new(dst: &'a mut BytesMut) -> Self {
235 WriteBytes { dst }
236 }
237}
238
239impl<'a> std::io::Write for WriteBytes<'a> {
240 fn write(&mut self, buf: &[u8]) -> std::result::Result<usize, std::io::Error> {
241 self.dst.put_slice(buf);
242 Ok(buf.len())
243 }
244
245 fn flush(&mut self) -> std::result::Result<(), std::io::Error> {
246 Ok(())
247 }
248}
249
250#[cfg(test)]
251mod tests {
252 use super::*;
253
254 #[test]
255 fn test_stream_start() {
256 let mut c = XmppCodec::new();
257 let mut b = BytesMut::with_capacity(1024);
258 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
259 let r = c.decode(&mut b);
260 assert!(match r {
261 Ok(Some(Packet::StreamStart(_))) => true,
262 _ => false,
263 });
264 }
265
266 #[test]
267 fn test_stream_end() {
268 let mut c = XmppCodec::new();
269 let mut b = BytesMut::with_capacity(1024);
270 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
271 let r = c.decode(&mut b);
272 assert!(match r {
273 Ok(Some(Packet::StreamStart(_))) => true,
274 _ => false,
275 });
276 b.put_slice(b"</stream:stream>");
277 let r = c.decode(&mut b);
278 assert!(match r {
279 Ok(Some(Packet::StreamEnd)) => true,
280 _ => false,
281 });
282 }
283
284 #[test]
285 fn test_truncated_stanza() {
286 let mut c = XmppCodec::new();
287 let mut b = BytesMut::with_capacity(1024);
288 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
289 let r = c.decode(&mut b);
290 assert!(match r {
291 Ok(Some(Packet::StreamStart(_))) => true,
292 _ => false,
293 });
294
295 b.put_slice("<test>ß</test".as_bytes());
296 let r = c.decode(&mut b);
297 assert!(match r {
298 Ok(None) => true,
299 _ => false,
300 });
301
302 b.put_slice(b">");
303 let r = c.decode(&mut b);
304 assert!(match r {
305 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
306 _ => false,
307 });
308 }
309
310 #[test]
311 fn test_truncated_utf8() {
312 let mut c = XmppCodec::new();
313 let mut b = BytesMut::with_capacity(1024);
314 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
315 let r = c.decode(&mut b);
316 assert!(match r {
317 Ok(Some(Packet::StreamStart(_))) => true,
318 _ => false,
319 });
320
321 b.put(&b"<test>\xc3"[..]);
322 let r = c.decode(&mut b);
323 assert!(match r {
324 Ok(None) => true,
325 _ => false,
326 });
327
328 b.put(&b"\x9f</test>"[..]);
329 let r = c.decode(&mut b);
330 assert!(match r {
331 Ok(Some(Packet::Stanza(ref el))) if el.name() == "test" && el.text() == "ß" => true,
332 _ => false,
333 });
334 }
335
336 /// test case for https://gitlab.com/xmpp-rs/tokio-xmpp/issues/3
337 #[test]
338 fn test_atrribute_prefix() {
339 let mut c = XmppCodec::new();
340 let mut b = BytesMut::with_capacity(1024);
341 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
342 let r = c.decode(&mut b);
343 assert!(match r {
344 Ok(Some(Packet::StreamStart(_))) => true,
345 _ => false,
346 });
347
348 b.put_slice(b"<status xml:lang='en'>Test status</status>");
349 let r = c.decode(&mut b);
350 assert!(match r {
351 Ok(Some(Packet::Stanza(ref el)))
352 if el.name() == "status"
353 && el.text() == "Test status"
354 && el.attr("xml:lang").map_or(false, |a| a == "en") =>
355 true,
356 _ => false,
357 });
358 }
359
360 /// By default, encode() only gets a BytesMut that has 8 KiB space reserved.
361 #[test]
362 fn test_large_stanza() {
363 use futures::{executor::block_on, sink::SinkExt};
364 use std::io::Cursor;
365 use tokio_util::codec::FramedWrite;
366 let mut framed = FramedWrite::new(Cursor::new(vec![]), XmppCodec::new());
367 let mut text = "".to_owned();
368 for _ in 0..2usize.pow(15) {
369 text = text + "A";
370 }
371 let stanza = Element::builder("message", "jabber:client")
372 .append(
373 Element::builder("body", "jabber:client")
374 .append(text.as_ref())
375 .build(),
376 )
377 .build();
378 block_on(framed.send(Packet::Stanza(stanza))).expect("send");
379 assert_eq!(
380 framed.get_ref().get_ref(),
381 &format!(
382 "<message xmlns='jabber:client'><body>{}</body></message>",
383 text
384 )
385 .as_bytes()
386 );
387 }
388
389 #[test]
390 fn test_cut_out_stanza() {
391 let mut c = XmppCodec::new();
392 let mut b = BytesMut::with_capacity(1024);
393 b.put_slice(b"<?xml version='1.0'?><stream:stream xmlns:stream='http://etherx.jabber.org/streams' version='1.0' xmlns='jabber:client'>");
394 let r = c.decode(&mut b);
395 assert!(match r {
396 Ok(Some(Packet::StreamStart(_))) => true,
397 _ => false,
398 });
399
400 b.put_slice(b"<message ");
401 b.put_slice(b"type='chat'><body>Foo</body></message>");
402 let r = c.decode(&mut b);
403 assert!(match r {
404 Ok(Some(Packet::Stanza(_))) => true,
405 _ => false,
406 });
407 }
408}