JIDs now have typed and stringy methods for node/domain/resource access

xmppftw created

Jid now has typed with_resource and stringy with_resource_str
Jid now has is_full, is_bare

Change summary

jid/src/lib.rs                         | 152 +++++++++++++++++++++++----
jid/src/parts.rs                       |  44 +++++++-
tokio-xmpp/src/client/async_client.rs  |   4 
tokio-xmpp/src/client/bind.rs          |   2 
tokio-xmpp/src/client/simple_client.rs |   4 
tokio-xmpp/src/starttls.rs             |   2 
tokio-xmpp/src/stream_start.rs         |   2 
xmpp/src/lib.rs                        |   6 
8 files changed, 177 insertions(+), 39 deletions(-)

Detailed changes

jid/src/lib.rs 🔗

@@ -34,7 +34,6 @@ use core::num::NonZeroU16;
 use std::convert::TryFrom;
 use std::fmt;
 use std::str::FromStr;
-use stringprep::resourceprep;
 
 #[cfg(feature = "serde")]
 use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
@@ -103,9 +102,9 @@ impl Jid {
     /// # fn main() -> Result<(), Error> {
     /// let jid = Jid::new("node@domain/resource")?;
     ///
-    /// assert_eq!(jid.node(), Some("node"));
-    /// assert_eq!(jid.domain(), "domain");
-    /// assert_eq!(jid.resource(), Some("resource"));
+    /// assert_eq!(jid.node_str(), Some("node"));
+    /// assert_eq!(jid.domain_str(), "domain");
+    /// assert_eq!(jid.resource_str(), Some("resource"));
     /// # Ok(())
     /// # }
     /// ```
@@ -133,22 +132,51 @@ impl Jid {
         }
     }
 
-    /// The optional node part of the JID.
-    pub fn node(&self) -> Option<&str> {
+    /// The optional node part of the JID, as a [`NodePart`]
+    pub fn node(&self) -> Option<NodePart> {
+        match self {
+            Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => {
+                inner.node().map(|s| NodePart::new_unchecked(s))
+            }
+        }
+    }
+
+    /// The optional node part of the JID, as a stringy reference
+    pub fn node_str(&self) -> Option<&str> {
         match self {
             Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => inner.node(),
         }
     }
 
-    /// The domain part of the JID.
-    pub fn domain(&self) -> &str {
+    /// The domain part of the JID, as a [`DomainPart`]
+    pub fn domain(&self) -> DomainPart {
+        match self {
+            Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => {
+                DomainPart::new_unchecked(inner.domain())
+            }
+        }
+    }
+
+    /// The domain part of the JID, as a stringy reference
+    pub fn domain_str(&self) -> &str {
         match self {
             Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => inner.domain(),
         }
     }
 
-    /// The optional resource part of the JID.
-    pub fn resource(&self) -> Option<&str> {
+    /// The optional resource part of the JID, as a [`ResourcePart`]. It is guaranteed to be present
+    /// when the JID is a Full variant, which you can check with [`Jid::is_full`].
+    pub fn resource(&self) -> Option<ResourcePart> {
+        match self {
+            Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => {
+                inner.resource().map(|s| ResourcePart::new_unchecked(s))
+            }
+        }
+    }
+
+    /// The optional resource of the Jabber ID. It is guaranteed to be present when the JID is
+    /// a Full variant, which you can check with [`Jid::is_full`].
+    pub fn resource_str(&self) -> Option<&str> {
         match self {
             Jid::Bare(BareJid { inner }) | Jid::Full(FullJid { inner }) => inner.resource(),
         }
@@ -169,6 +197,19 @@ impl Jid {
             Jid::Bare(jid) => jid,
         }
     }
+
+    /// Checks if the JID contains a [`FullJid`]
+    pub fn is_full(&self) -> bool {
+        match self {
+            Self::Full(_) => true,
+            Self::Bare(_) => false,
+        }
+    }
+
+    /// Checks if the JID contains a [`BareJid`]
+    pub fn is_bare(&self) -> bool {
+        !self.is_full()
+    }
 }
 
 impl TryFrom<Jid> for FullJid {
@@ -488,22 +529,23 @@ impl BareJid {
         self.inner.domain()
     }
 
-    /// Constructs a [`FullJid`] from the bare JID, by specifying a `resource`.
+    /// Constructs a [`BareJid`] from the bare JID, by specifying a [`ResourcePart`].
+    /// If you'd like to specify a stringy resource, use [`BareJid::with_resource_str`] instead.
     ///
     /// # Examples
     ///
     /// ```
-    /// use jid::BareJid;
+    /// use jid::{BareJid, ResourcePart};
     ///
+    /// let resource = ResourcePart::new("resource").unwrap();
     /// let bare = BareJid::new("node@domain").unwrap();
-    /// let full = bare.with_resource("resource").unwrap();
+    /// let full = bare.with_resource(&resource);
     ///
     /// assert_eq!(full.node(), Some("node"));
     /// assert_eq!(full.domain(), "domain");
     /// assert_eq!(full.resource(), "resource");
     /// ```
-    pub fn with_resource(&self, resource: &str) -> Result<FullJid, Error> {
-        let resource = resourceprep(resource).map_err(|_| Error::ResourcePrep)?;
+    pub fn with_resource(&self, resource: &ResourcePart) -> FullJid {
         let slash = NonZeroU16::new(self.inner.normalized.len() as u16);
         let normalized = format!("{}/{resource}", self.inner.normalized);
         let inner = InnerJid {
@@ -511,7 +553,28 @@ impl BareJid {
             at: self.inner.at,
             slash,
         };
-        Ok(FullJid { inner })
+
+        FullJid { inner }
+    }
+
+    /// Constructs a [`FullJid`] from the bare JID, by specifying a stringy `resource`.
+    /// If your resource has already been parsed into a [`ResourcePart`], use [`BareJid::with_resource`].
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use jid::BareJid;
+    ///
+    /// let bare = BareJid::new("node@domain").unwrap();
+    /// let full = bare.with_resource_str("resource").unwrap();
+    ///
+    /// assert_eq!(full.node(), Some("node"));
+    /// assert_eq!(full.domain(), "domain");
+    /// assert_eq!(full.resource(), "resource");
+    /// ```
+    pub fn with_resource_str(&self, resource: &str) -> Result<FullJid, Error> {
+        let resource = ResourcePart::new(resource)?;
+        Ok(self.with_resource(&resource))
     }
 }
 
@@ -634,24 +697,51 @@ mod tests {
     }
 
     #[test]
-    fn bare_to_full_jid() {
+    fn bare_to_full_jid_str() {
         assert_eq!(
-            BareJid::new("a@b.c").unwrap().with_resource("d").unwrap(),
+            BareJid::new("a@b.c")
+                .unwrap()
+                .with_resource_str("d")
+                .unwrap(),
             FullJid::new("a@b.c/d").unwrap()
         );
     }
 
     #[test]
-    fn node_from_jid() {
+    fn bare_to_full_jid() {
         assert_eq!(
-            Jid::Full(FullJid::new("a@b.c/d").unwrap()).node(),
-            Some("a"),
-        );
+            BareJid::new("a@b.c")
+                .unwrap()
+                .with_resource(&ResourcePart::new("d").unwrap()),
+            FullJid::new("a@b.c/d").unwrap()
+        )
+    }
+
+    #[test]
+    fn node_from_jid() {
+        let jid = Jid::new("a@b.c/d").unwrap();
+
+        assert_eq!(jid.node_str(), Some("a"),);
+
+        assert_eq!(jid.node(), Some(NodePart::new("a").unwrap()));
     }
 
     #[test]
     fn domain_from_jid() {
-        assert_eq!(Jid::Bare(BareJid::new("a@b.c").unwrap()).domain(), "b.c");
+        let jid = Jid::new("a@b.c").unwrap();
+
+        assert_eq!(jid.domain_str(), "b.c");
+
+        assert_eq!(jid.domain(), DomainPart::new("b.c").unwrap());
+    }
+
+    #[test]
+    fn resource_from_jid() {
+        let jid = Jid::new("a@b.c/d").unwrap();
+
+        assert_eq!(jid.resource_str(), Some("d"),);
+
+        assert_eq!(jid.resource(), Some(ResourcePart::new("d").unwrap()));
     }
 
     #[test]
@@ -772,4 +862,20 @@ mod tests {
         let equiv = FullJid::new("test@☃.com/TestTM").unwrap();
         assert_eq!(full, equiv);
     }
+
+    #[test]
+    fn jid_from_parts() {
+        let node = NodePart::new("node").unwrap();
+        let domain = DomainPart::new("domain").unwrap();
+        let resource = ResourcePart::new("resource").unwrap();
+
+        let jid = Jid::from_parts(Some(&node), &domain, Some(&resource));
+        assert_eq!(jid, Jid::new("node@domain/resource").unwrap());
+
+        let barejid = BareJid::from_parts(Some(&node), &domain);
+        assert_eq!(barejid, BareJid::new("node@domain").unwrap());
+
+        let fulljid = FullJid::from_parts(Some(&node), &domain, &resource);
+        assert_eq!(fulljid, FullJid::new("node@domain/resource").unwrap());
+    }
 }

jid/src/parts.rs 🔗

@@ -1,10 +1,8 @@
 use stringprep::{nameprep, nodeprep, resourceprep};
 
-use crate::Error;
+use std::fmt;
 
-/// The [`NodePart`] is the optional part before the (optional) `@` in any [`Jid`], whether [`BareJid`] or [`FullJid`].
-#[derive(Clone, Debug, PartialEq, Hash, PartialOrd)]
-pub struct NodePart(pub(crate) String);
+use crate::Error;
 
 fn length_check(len: usize, error_empty: Error, error_too_long: Error) -> Result<(), Error> {
     if len == 0 {
@@ -16,6 +14,10 @@ fn length_check(len: usize, error_empty: Error, error_too_long: Error) -> Result
     }
 }
 
+/// The [`NodePart`] is the optional part before the (optional) `@` in any [`Jid`], whether [`BareJid`] or [`FullJid`].
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct NodePart(pub(crate) String);
+
 impl NodePart {
     /// Build a new [`NodePart`] from a string slice. Will fail in case of stringprep validation error.
     pub fn new(s: &str) -> Result<NodePart, Error> {
@@ -23,10 +25,20 @@ impl NodePart {
         length_check(node.len(), Error::NodeEmpty, Error::NodeTooLong)?;
         Ok(NodePart(node.to_string()))
     }
+
+    pub(crate) fn new_unchecked(s: &str) -> NodePart {
+        NodePart(s.to_string())
+    }
+}
+
+impl fmt::Display for NodePart {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
 }
 
 /// The [`DomainPart`] is the part between the (optional) `@` and the (optional) `/` in any [`Jid`], whether [`BareJid`] or [`FullJid`].
-#[derive(Clone, Debug, PartialEq, Hash, PartialOrd)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
 pub struct DomainPart(pub(crate) String);
 
 impl DomainPart {
@@ -36,10 +48,20 @@ impl DomainPart {
         length_check(domain.len(), Error::DomainEmpty, Error::DomainTooLong)?;
         Ok(DomainPart(domain.to_string()))
     }
+
+    pub(crate) fn new_unchecked(s: &str) -> DomainPart {
+        DomainPart(s.to_string())
+    }
+}
+
+impl fmt::Display for DomainPart {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
 }
 
 /// The [`ResourcePart`] is the optional part after the `/` in a [`Jid`]. It is mandatory in [`FullJid`].
-#[derive(Clone, Debug, PartialEq, Hash, PartialOrd)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
 pub struct ResourcePart(pub(crate) String);
 
 impl ResourcePart {
@@ -49,4 +71,14 @@ impl ResourcePart {
         length_check(resource.len(), Error::ResourceEmpty, Error::ResourceTooLong)?;
         Ok(ResourcePart(resource.to_string()))
     }
+
+    pub(crate) fn new_unchecked(s: &str) -> ResourcePart {
+        ResourcePart(s.to_string())
+    }
+}
+
+impl fmt::Display for ResourcePart {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
 }

tokio-xmpp/src/client/async_client.rs 🔗

@@ -109,13 +109,13 @@ impl Client {
         jid: Jid,
         password: String,
     ) -> Result<XMPPStream, Error> {
-        let username = jid.node().unwrap();
+        let username = jid.node_str().unwrap();
         let password = password;
 
         // TCP connection
         let tcp_stream = match server {
             ServerConfig::UseSrv => {
-                connect_with_srv(jid.domain(), "_xmpp-client._tcp", 5222).await?
+                connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
             }
             ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
         };

tokio-xmpp/src/client/bind.rs 🔗

@@ -17,7 +17,7 @@ pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>(
     if stream.stream_features.can_bind() {
         let resource = stream
             .jid
-            .resource()
+            .resource_str()
             .and_then(|resource| Some(resource.to_owned()));
         let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
         stream.send_stanza(iq).await?;

tokio-xmpp/src/client/simple_client.rs 🔗

@@ -50,9 +50,9 @@ impl Client {
     }
 
     async fn connect(jid: Jid, password: String) -> Result<XMPPStream, Error> {
-        let username = jid.node().unwrap();
+        let username = jid.node_str().unwrap();
         let password = password;
-        let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
+        let domain = idna::domain_to_ascii(&jid.clone().domain_str()).map_err(|_| Error::Idna)?;
 
         // TCP connection
         let tcp_stream = connect_with_srv(&domain, "_xmpp-client._tcp", 5222).await?;

tokio-xmpp/src/starttls.rs 🔗

@@ -29,7 +29,7 @@ use crate::{Error, ProtocolError};
 async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
     xmpp_stream: XMPPStream<S>,
 ) -> Result<TlsStream<S>, Error> {
-    let domain = xmpp_stream.jid.domain().to_owned();
+    let domain = xmpp_stream.jid.domain_str().to_owned();
     let stream = xmpp_stream.into_inner();
     let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
         .connect(&domain, stream)

tokio-xmpp/src/stream_start.rs 🔗

@@ -16,7 +16,7 @@ pub async fn start<S: AsyncRead + AsyncWrite + Unpin>(
     ns: String,
 ) -> Result<XMPPStream<S>, Error> {
     let attrs = [
-        ("to".to_owned(), jid.domain().to_owned()),
+        ("to".to_owned(), jid.domain_str().to_owned()),
         ("version".to_owned(), "1.0".to_owned()),
         ("xmlns".to_owned(), ns.clone()),
         ("xmlns:stream".to_owned(), ns::STREAM.to_owned()),

xmpp/src/lib.rs 🔗

@@ -180,7 +180,7 @@ impl ClientBuilder<'_> {
 
     pub fn build(self) -> Agent {
         let jid: Jid = if let Some(resource) = &self.resource {
-            self.jid.with_resource(resource).unwrap().into()
+            self.jid.with_resource_str(resource).unwrap().into()
         } else {
             self.jid.clone().into()
         };
@@ -233,7 +233,7 @@ impl Agent {
         }
 
         let nick = nick.unwrap_or_else(|| self.default_nick.read().unwrap().clone());
-        let room_jid = room.with_resource(&nick).unwrap();
+        let room_jid = room.with_resource_str(&nick).unwrap();
         let mut presence = Presence::new(PresenceType::None).with_to(room_jid);
         presence.add_payload(muc);
         presence.set_status(String::from(lang), String::from(status));
@@ -262,7 +262,7 @@ impl Agent {
         lang: &str,
         text: &str,
     ) {
-        let recipient: Jid = room.with_resource(&recipient).unwrap().into();
+        let recipient: Jid = room.with_resource_str(&recipient).unwrap().into();
         let mut message = Message::new(recipient).with_payload(MucUser::new());
         message.type_ = MessageType::Chat;
         message