jid: Test for too many @ before the resourcepart

Emmanuel Gil Peyrot created

This should additionally optimize the parsing a tiny bit by looking for
both @ and / at the same time and iterating on them, instead of searching
for each one separately.

Thanks nicoco for the report!

Change summary

jid/CHANGELOG.md |   6 ++
jid/src/error.rs |   4 +
jid/src/lib.rs   | 118 +++++++++++++++++++++++++++++--------------------
3 files changed, 80 insertions(+), 48 deletions(-)

Detailed changes

jid/CHANGELOG.md 🔗

@@ -1,3 +1,9 @@
+Version NEXT:
+  * Additions:
+    - Add missing check for JIDs with too many `@` before the resource, such as
+      `a@b@c` or `a@b@c/d` which should clearly be invalid.  The new error it
+      produces is named `TooManyAts`.
+
 Version 0.11.1, release 2024-07-23:
   * Breaking:
     - Move InnerJid into Jid and reformulate BareJid and FullJid in terms of

jid/src/error.rs 🔗

@@ -46,6 +46,9 @@ pub enum Error {
 
     /// Happens when parsing a bare JID and there is a resource.
     ResourceInBareJid,
+
+    /// Happens when parsing a JID which has two @ before the resource.
+    TooManyAts,
 }
 
 impl core::error::Error for Error {}
@@ -64,6 +67,7 @@ impl fmt::Display for Error {
             Error::ResourcePrep => "resource doesn’t pass resourceprep validation",
             Error::ResourceMissingInFullJid => "no resource found in this full JID",
             Error::ResourceInBareJid => "resource found while parsing a bare JID",
+            Error::TooManyAts => "second @ found before parsing the resource",
         })
     }
 }

jid/src/lib.rs 🔗

@@ -49,7 +49,7 @@ use core::num::NonZeroU16;
 use core::ops::Deref;
 use core::str::FromStr;
 
-use memchr::memchr;
+use memchr::memchr2_iter;
 
 use stringprep::{nameprep, nodeprep, resourceprep};
 
@@ -170,67 +170,87 @@ impl Jid {
     /// ```
     pub fn new(unnormalized: &str) -> Result<Jid, Error> {
         let bytes = unnormalized.as_bytes();
-        let mut orig_at = memchr(b'@', bytes);
-        let mut orig_slash = memchr(b'/', bytes);
-        if orig_at.is_some() && orig_slash.is_some() && orig_at > orig_slash {
-            // This is part of the resource, not a node@domain separator.
-            orig_at = None;
-        }
-
-        let normalized = match (orig_at, orig_slash) {
-            (Some(at), Some(slash)) => {
-                let node = nodeprep(&unnormalized[..at]).map_err(|_| Error::NodePrep)?;
-                length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
-
-                let domain = nameprep(&unnormalized[at + 1..slash]).map_err(|_| Error::NamePrep)?;
-                length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
-
-                let resource =
-                    resourceprep(&unnormalized[slash + 1..]).map_err(|_| Error::ResourcePrep)?;
-                length_check(resource.len(), Error::ResourceEmpty, Error::ResourceTooLong)?;
-
-                orig_at = Some(node.len());
-                orig_slash = Some(node.len() + domain.len() + 1);
-                match (node, domain, resource) {
-                    (Cow::Borrowed(_), Cow::Borrowed(_), Cow::Borrowed(_)) => {
-                        unnormalized.to_string()
+        let orig_at;
+        let orig_slash;
+        let mut iter = memchr2_iter(b'@', b'/', bytes);
+        let normalized = if let Some(first_index) = iter.next() {
+            let byte = bytes[first_index];
+            if byte == b'@' {
+                if let Some(second_index) = iter.next() {
+                    let byte = bytes[second_index];
+                    if byte == b'/' {
+                        let node =
+                            nodeprep(&unnormalized[..first_index]).map_err(|_| Error::NodePrep)?;
+                        length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
+
+                        let domain = nameprep(&unnormalized[first_index + 1..second_index])
+                            .map_err(|_| Error::NamePrep)?;
+                        length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+
+                        let resource = resourceprep(&unnormalized[second_index + 1..])
+                            .map_err(|_| Error::ResourcePrep)?;
+                        length_check(resource.len(), Error::ResourceEmpty, Error::ResourceTooLong)?;
+
+                        orig_at = Some(node.len());
+                        orig_slash = Some(node.len() + domain.len() + 1);
+                        match (node, domain, resource) {
+                            (Cow::Borrowed(_), Cow::Borrowed(_), Cow::Borrowed(_)) => {
+                                unnormalized.to_string()
+                            }
+                            (node, domain, resource) => format!("{node}@{domain}/{resource}"),
+                        }
+                    } else
+                    /* This is another '@' character. */
+                    {
+                        return Err(Error::TooManyAts);
+                    }
+                } else {
+                    // That is a node@domain JID.
+
+                    let node =
+                        nodeprep(&unnormalized[..first_index]).map_err(|_| Error::NodePrep)?;
+                    length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
+
+                    let domain =
+                        nameprep(&unnormalized[first_index + 1..]).map_err(|_| Error::NamePrep)?;
+                    length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+
+                    orig_at = Some(node.len());
+                    orig_slash = None;
+                    match (node, domain) {
+                        (Cow::Borrowed(_), Cow::Borrowed(_)) => unnormalized.to_string(),
+                        (node, domain) => format!("{node}@{domain}"),
                     }
-                    (node, domain, resource) => format!("{node}@{domain}/{resource}"),
                 }
-            }
-            (Some(at), None) => {
-                let node = nodeprep(&unnormalized[..at]).map_err(|_| Error::NodePrep)?;
-                length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
+            } else
+            /* This is a '/' character. */
+            {
+                // The JID is of the form domain/resource, we can stop looking for further
+                // characters.
 
-                let domain = nameprep(&unnormalized[at + 1..]).map_err(|_| Error::NamePrep)?;
+                let domain = nameprep(&unnormalized[..first_index]).map_err(|_| Error::NamePrep)?;
                 length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
 
-                orig_at = Some(node.len());
-                match (node, domain) {
-                    (Cow::Borrowed(_), Cow::Borrowed(_)) => unnormalized.to_string(),
-                    (node, domain) => format!("{node}@{domain}"),
-                }
-            }
-            (None, Some(slash)) => {
-                let domain = nameprep(&unnormalized[..slash]).map_err(|_| Error::NamePrep)?;
-                length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
-
-                let resource =
-                    resourceprep(&unnormalized[slash + 1..]).map_err(|_| Error::ResourcePrep)?;
+                let resource = resourceprep(&unnormalized[first_index + 1..])
+                    .map_err(|_| Error::ResourcePrep)?;
                 length_check(resource.len(), Error::ResourceEmpty, Error::ResourceTooLong)?;
 
+                orig_at = None;
                 orig_slash = Some(domain.len());
                 match (domain, resource) {
                     (Cow::Borrowed(_), Cow::Borrowed(_)) => unnormalized.to_string(),
                     (domain, resource) => format!("{domain}/{resource}"),
                 }
             }
-            (None, None) => {
-                let domain = nameprep(unnormalized).map_err(|_| Error::NamePrep)?;
-                length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+        } else {
+            // Last possible case, just a domain JID.
 
-                domain.into_owned()
-            }
+            let domain = nameprep(unnormalized).map_err(|_| Error::NamePrep)?;
+            length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
+
+            orig_at = None;
+            orig_slash = None;
+            domain.into_owned()
         };
 
         Ok(Self {
@@ -1059,6 +1079,8 @@ mod tests {
             Err(Error::ResourceMissingInFullJid)
         );
         assert_eq!(BareJid::from_str("a@b/c"), Err(Error::ResourceInBareJid));
+        assert_eq!(BareJid::from_str("a@b@c"), Err(Error::TooManyAts));
+        assert_eq!(FullJid::from_str("a@b@c/d"), Err(Error::TooManyAts));
     }
 
     #[test]