1// Copyright (c) 2017 Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
7use base64::{engine::general_purpose::STANDARD as Base64Engine, Engine};
8use jid::Jid;
9use std::str::FromStr;
10use xso::error::Error;
11
12/// A trait for codecs that can decode and encode text nodes.
13pub trait Codec {
14 type Decoded;
15
16 /// Decode the given string into the codec’s output.
17 fn decode(s: &str) -> Result<Self::Decoded, Error>;
18
19 /// Encode the given value; return None to not produce a text node at all.
20 fn encode(decoded: &Self::Decoded) -> Option<String>;
21}
22
23/// Codec for text content.
24pub struct Text;
25
26impl Codec for Text {
27 type Decoded = String;
28
29 fn decode(s: &str) -> Result<String, Error> {
30 Ok(s.to_owned())
31 }
32
33 fn encode(decoded: &String) -> Option<String> {
34 Some(decoded.to_owned())
35 }
36}
37
38/// Codec transformer that makes the text optional; a "" string is decoded as None.
39pub struct OptionalCodec<T: Codec>(std::marker::PhantomData<T>);
40
41impl<T> Codec for OptionalCodec<T>
42where
43 T: Codec,
44{
45 type Decoded = Option<T::Decoded>;
46
47 fn decode(s: &str) -> Result<Option<T::Decoded>, Error> {
48 if s.is_empty() {
49 return Ok(None);
50 }
51
52 Ok(Some(T::decode(s)?))
53 }
54
55 fn encode(decoded: &Option<T::Decoded>) -> Option<String> {
56 decoded.as_ref().and_then(T::encode)
57 }
58}
59
60/// Codec that trims whitespace around the text.
61pub struct Trimmed<T: Codec>(std::marker::PhantomData<T>);
62
63impl<T> Codec for Trimmed<T>
64where
65 T: Codec,
66{
67 type Decoded = T::Decoded;
68
69 fn decode(s: &str) -> Result<T::Decoded, Error> {
70 match s.trim() {
71 // TODO: This error message can be a bit opaque when used
72 // in-context; ideally it'd be configurable.
73 "" => Err(Error::Other(
74 "The text in the element's text node was empty after trimming.",
75 )),
76 trimmed => T::decode(trimmed),
77 }
78 }
79
80 fn encode(decoded: &T::Decoded) -> Option<String> {
81 T::encode(decoded)
82 }
83}
84
85/// Codec wrapping that encodes/decodes a string as base64.
86pub struct Base64;
87
88impl Codec for Base64 {
89 type Decoded = Vec<u8>;
90
91 fn decode(s: &str) -> Result<Vec<u8>, Error> {
92 Base64Engine.decode(s).map_err(Error::text_parse_error)
93 }
94
95 fn encode(decoded: &Vec<u8>) -> Option<String> {
96 Some(Base64Engine.encode(decoded))
97 }
98}
99
100/// Codec wrapping base64 encode/decode, while ignoring whitespace characters.
101pub struct WhitespaceAwareBase64;
102
103impl Codec for WhitespaceAwareBase64 {
104 type Decoded = Vec<u8>;
105
106 fn decode(s: &str) -> Result<Self::Decoded, Error> {
107 let s: String = s
108 .chars()
109 .filter(|ch| *ch != ' ' && *ch != '\n' && *ch != '\t')
110 .collect();
111
112 Base64Engine.decode(s).map_err(Error::text_parse_error)
113 }
114
115 fn encode(decoded: &Self::Decoded) -> Option<String> {
116 Some(Base64Engine.encode(decoded))
117 }
118}
119
120/// Codec for bytes of lowercase hexadecimal, with a fixed length `N` (in bytes).
121pub struct FixedHex<const N: usize>;
122
123impl<const N: usize> Codec for FixedHex<N> {
124 type Decoded = [u8; N];
125
126 fn decode(s: &str) -> Result<Self::Decoded, Error> {
127 if s.len() != 2 * N {
128 return Err(Error::Other("Invalid length"));
129 }
130
131 let mut bytes = [0u8; N];
132 for i in 0..N {
133 bytes[i] =
134 u8::from_str_radix(&s[2 * i..2 * i + 2], 16).map_err(Error::text_parse_error)?;
135 }
136
137 Ok(bytes)
138 }
139
140 fn encode(decoded: &Self::Decoded) -> Option<String> {
141 let mut bytes = String::with_capacity(N * 2);
142 for byte in decoded {
143 bytes.extend(format!("{:02x}", byte).chars());
144 }
145 Some(bytes)
146 }
147}
148
149/// Codec for colon-separated bytes of uppercase hexadecimal.
150pub struct ColonSeparatedHex;
151
152impl Codec for ColonSeparatedHex {
153 type Decoded = Vec<u8>;
154
155 fn decode(s: &str) -> Result<Self::Decoded, Error> {
156 let mut bytes = vec![];
157 for i in 0..(1 + s.len()) / 3 {
158 let byte =
159 u8::from_str_radix(&s[3 * i..3 * i + 2], 16).map_err(Error::text_parse_error)?;
160 if 3 * i + 2 < s.len() {
161 assert_eq!(&s[3 * i + 2..3 * i + 3], ":");
162 }
163 bytes.push(byte);
164 }
165 Ok(bytes)
166 }
167
168 fn encode(decoded: &Self::Decoded) -> Option<String> {
169 let mut bytes = vec![];
170 for byte in decoded {
171 bytes.push(format!("{:02X}", byte));
172 }
173 Some(bytes.join(":"))
174 }
175}
176
177/// Codec for a JID.
178pub struct JidCodec;
179
180impl Codec for JidCodec {
181 type Decoded = Jid;
182
183 fn decode(s: &str) -> Result<Jid, Error> {
184 Jid::from_str(s).map_err(Error::text_parse_error)
185 }
186
187 fn encode(jid: &Jid) -> Option<String> {
188 Some(jid.to_string())
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fixed_hex() {
        let value = [0x01, 0xfe, 0xef];

        // Decoding accepts uppercase as well as lowercase digits.
        assert_eq!(&FixedHex::<3>::decode("01feEF").unwrap(), &value);

        // Encoding always emits lowercase digits.
        assert_eq!(FixedHex::<3>::encode(&value).unwrap(), "01feef");

        // Malformed inputs, each paired with the exact error message expected.
        let rejected = [
            // Longer than 2 * N digits.
            ("01feEF01", "Invalid length"),
            // Shorter than 2 * N digits.
            ("01fe", "Invalid length"),
            // An odd number of digits.
            ("01feE", "Invalid length"),
            // Colon separators are not supported.
            ("0:f:EF", "text parse error: invalid digit found in string"),
            // Characters outside the hex alphabet.
            ("01defg", "text parse error: invalid digit found in string"),
        ];
        for &(input, expected) in &rejected {
            let err = FixedHex::<3>::decode(input).unwrap_err();
            assert_eq!(err.to_string(), expected, "unexpected error for {:?}", input);
        }
    }
}