jid: Validate domains against idna

Emmanuel Gil Peyrot created

The idna crate validates against UTS 46, which supports domains using
either IDNA2003 or IDNA2008.  This allows us to support both old and
new internationalized domain names.

This dependency isn’t new to the tree: hickory-proto, url and tokio-xmpp
all already depend on it.

There are a bunch of other checks that have to be performed; these are
inspired by slixmpp’s JID implementation.

Change summary

jid/CHANGELOG.md      |   2 
jid/Cargo.toml        |   1 
jid/src/error.rs      |   4 +
jid/src/lib.rs        | 114 ++++++++++++++++++++++++++++++++++----------
parsers/src/roster.rs |   2 
5 files changed, 95 insertions(+), 28 deletions(-)

Detailed changes

jid/CHANGELOG.md 🔗

@@ -3,6 +3,8 @@ Version NEXT:
     - Add missing check for JIDs with too many `@` before the resource, such as
       `a@b@c` or `a@b@c/d` which should clearly be invalid.  The new error it
       produces is named `TooManyAts`.
+    - domainparts are now checked much more in-depth, using the idna crate and
+      various custom rules.
 
 Version 0.11.1, release 2024-07-23:
   * Breaking:

jid/Cargo.toml 🔗

@@ -24,6 +24,7 @@ serde = { version = "1.0", features = ["derive"], optional = true }
 stringprep = "0.1.3"
 quote = { version = "1.0", optional = true }
 proc-macro2 = { version = "1.0", optional = true }
+idna = "1"
 # same repository dependencies
 minidom = { version = "0.16", path = "../minidom", optional = true }
 

jid/src/error.rs 🔗

@@ -49,6 +49,9 @@ pub enum Error {
 
     /// Happens when parsing a JID which has two @ before the resource.
     TooManyAts,
+
+    /// Happens when the domain is invalid according to idna.
+    Idna,
 }
 
 impl core::error::Error for Error {}
@@ -68,6 +71,7 @@ impl fmt::Display for Error {
             Error::ResourceMissingInFullJid => "no resource found in this full JID",
             Error::ResourceInBareJid => "resource found while parsing a bare JID",
             Error::TooManyAts => "second @ found before parsing the resource",
+            Error::Idna => "domain doesn’t pass idna validation",
         })
     }
 }

jid/src/lib.rs 🔗

@@ -45,12 +45,14 @@ use core::cmp::Ordering;
 use core::fmt;
 use core::hash::{Hash, Hasher};
 use core::mem;
+use core::net::{Ipv4Addr, Ipv6Addr};
 use core::num::NonZeroU16;
 use core::ops::Deref;
 use core::str::FromStr;
 
 use memchr::memchr2_iter;
 
+use idna::uts46::{AsciiDenyList, DnsLength, Hyphens, Uts46};
 use stringprep::{nameprep, nodeprep, resourceprep};
 
 #[cfg(feature = "serde")]
@@ -80,6 +82,37 @@ fn length_check(len: usize, error_empty: Error, error_too_long: Error) -> Result
     }
 }
 
+fn domain_check(mut domain: &str) -> Result<Cow<'_, str>, Error> {
+    // First, check if this is an IPv4 address.
+    if Ipv4Addr::from_str(domain).is_ok() {
+        return Ok(Cow::Borrowed(domain));
+    }
+
+    // Then if this is an IPv6 address.
+    if domain.starts_with('[') && domain.ends_with(']') {
+        if Ipv6Addr::from_str(&domain[1..domain.len() - 1]).is_ok() {
+            return Ok(Cow::Borrowed(domain));
+        }
+    }
+
+    // idna can handle the root dot for us, but we still want to remove it for normalization
+    // purposes.
+    if domain.ends_with('.') {
+        domain = &domain[..domain.len() - 1];
+    }
+
+    Uts46::new()
+        .to_ascii(
+            domain.as_bytes(),
+            AsciiDenyList::URL,
+            Hyphens::Check,
+            DnsLength::Verify,
+        )
+        .map_err(|_| Error::Idna)?;
+    let domain = nameprep(domain).map_err(|_| Error::NamePrep)?;
+    Ok(domain)
+}
+
 /// A struct representing a Jabber ID (JID).
 ///
 /// This JID can either be "bare" (without a `/resource` suffix) or full (with
@@ -183,9 +216,7 @@ impl Jid {
                             nodeprep(&unnormalized[..first_index]).map_err(|_| Error::NodePrep)?;
                         length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
 
-                        let domain = nameprep(&unnormalized[first_index + 1..second_index])
-                            .map_err(|_| Error::NamePrep)?;
-                        length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+                        let domain = domain_check(&unnormalized[first_index + 1..second_index])?;
 
                         let resource = resourceprep(&unnormalized[second_index + 1..])
                             .map_err(|_| Error::ResourcePrep)?;
@@ -211,9 +242,7 @@ impl Jid {
                         nodeprep(&unnormalized[..first_index]).map_err(|_| Error::NodePrep)?;
                     length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
 
-                    let domain =
-                        nameprep(&unnormalized[first_index + 1..]).map_err(|_| Error::NamePrep)?;
-                    length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+                    let domain = domain_check(&unnormalized[first_index + 1..])?;
 
                     orig_at = Some(node.len());
                     orig_slash = None;
@@ -228,8 +257,7 @@ impl Jid {
                 // The JID is of the form domain/resource, we can stop looking for further
                 // characters.
 
-                let domain = nameprep(&unnormalized[..first_index]).map_err(|_| Error::NamePrep)?;
-                length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+                let domain = domain_check(&unnormalized[..first_index])?;
 
                 let resource = resourceprep(&unnormalized[first_index + 1..])
                     .map_err(|_| Error::ResourcePrep)?;
@@ -244,9 +272,7 @@ impl Jid {
             }
         } else {
             // Last possible case, just a domain JID.
-
-            let domain = nameprep(unnormalized).map_err(|_| Error::NamePrep)?;
-            length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+            let domain = domain_check(unnormalized)?;
 
             orig_at = None;
             orig_slash = None;
@@ -1063,15 +1089,15 @@ mod tests {
 
     #[test]
     fn invalid_jids() {
-        assert_eq!(BareJid::from_str(""), Err(Error::DomainEmpty));
-        assert_eq!(BareJid::from_str("/c"), Err(Error::DomainEmpty));
-        assert_eq!(BareJid::from_str("a@/c"), Err(Error::DomainEmpty));
+        assert_eq!(BareJid::from_str(""), Err(Error::Idna));
+        assert_eq!(BareJid::from_str("/c"), Err(Error::Idna));
+        assert_eq!(BareJid::from_str("a@/c"), Err(Error::Idna));
         assert_eq!(BareJid::from_str("@b"), Err(Error::NodeEmpty));
         assert_eq!(BareJid::from_str("b/"), Err(Error::ResourceEmpty));
 
-        assert_eq!(FullJid::from_str(""), Err(Error::DomainEmpty));
-        assert_eq!(FullJid::from_str("/c"), Err(Error::DomainEmpty));
-        assert_eq!(FullJid::from_str("a@/c"), Err(Error::DomainEmpty));
+        assert_eq!(FullJid::from_str(""), Err(Error::Idna));
+        assert_eq!(FullJid::from_str("/c"), Err(Error::Idna));
+        assert_eq!(FullJid::from_str("a@/c"), Err(Error::Idna));
         assert_eq!(FullJid::from_str("@b"), Err(Error::NodeEmpty));
         assert_eq!(FullJid::from_str("b/"), Err(Error::ResourceEmpty));
         assert_eq!(
@@ -1148,6 +1174,31 @@ mod tests {
         FullJid::from_str("a@b/🎉").unwrap_err();
     }
 
+    #[test]
+    fn idna() {
+        let bare = BareJid::from_str("Weiß.com.").unwrap();
+        let equiv = BareJid::new("weiss.com").unwrap();
+        assert_eq!(bare, equiv);
+        BareJid::from_str("127.0.0.1").unwrap();
+        BareJid::from_str("[::1]").unwrap();
+        BareJid::from_str("domain.tld.").unwrap();
+    }
+
+    #[test]
+    fn invalid_idna() {
+        BareJid::from_str("a@b@c").unwrap_err();
+        FullJid::from_str("a@b@c/d").unwrap_err();
+        BareJid::from_str("[::1234").unwrap_err();
+        BareJid::from_str("1::1234]").unwrap_err();
+        BareJid::from_str("domain.tld:5222").unwrap_err();
+        BareJid::from_str("-domain.tld").unwrap_err();
+        BareJid::from_str("domain.tld-").unwrap_err();
+        BareJid::from_str("domain..tld").unwrap_err();
+        BareJid::from_str("domain.tld..").unwrap_err();
+        BareJid::from_str("1234567890123456789012345678901234567890123456789012345678901234.com")
+            .unwrap_err();
+    }
+
     #[test]
     fn jid_from_parts() {
         let node = NodePart::new("node").unwrap();
@@ -1364,21 +1415,30 @@ mod tests {
 
     #[test]
     fn reject_long_domainpart() {
-        let mut long = Vec::with_capacity(1028);
+        let mut long = Vec::with_capacity(66);
         long.push(b'x');
         long.push(b'@');
-        long.resize(1026, b'a');
+        long.resize(66, b'a');
         let long = String::from_utf8(long).unwrap();
 
-        match Jid::new(&long) {
-            Err(Error::DomainTooLong) => (),
-            other => panic!("unexpected result: {:?}", other),
-        }
+        Jid::new(&long).unwrap_err();
+        BareJid::new(&long).unwrap_err();
 
-        match BareJid::new(&long) {
-            Err(Error::DomainTooLong) => (),
-            other => panic!("unexpected result: {:?}", other),
-        }
+        // A domain can be up to 253 bytes.
+        let mut long = Vec::with_capacity(256);
+        long.push(b'x');
+        long.push(b'@');
+        long.resize(65, b'a');
+        long.push(b'.');
+        long.resize(129, b'b');
+        long.push(b'.');
+        long.resize(193, b'c');
+        long.push(b'.');
+        long.resize(256, b'd');
+        let long = String::from_utf8(long).unwrap();
+
+        Jid::new(&long).unwrap_err();
+        BareJid::new(&long).unwrap_err();
     }
 
     #[test]

parsers/src/roster.rs 🔗

@@ -312,7 +312,7 @@ mod tests {
         let error = Roster::try_from(elem).unwrap_err();
         assert_eq!(
             format!("{error}"),
-            "text parse error: no domain found in this JID"
+            "text parse error: domain doesn’t pass idna validation"
         );
     }