update to sasl-rs 0.3.0, process error stanzas

lumi created

Change summary

Cargo.toml                     |   5 
src/client.rs                  |  39 +++++--
src/components/mod.rs          |   3 
src/components/sasl_error.rs   |  75 +++++++++++++++
src/components/stanza_error.rs | 172 ++++++++++++++++++++++++++++++++++++
src/error.rs                   |   3 
src/lib.rs                     |   2 
src/ns.rs                      |   1 
src/util.rs                    |  19 +++
9 files changed, 305 insertions(+), 14 deletions(-)

Detailed changes

Cargo.toml 🔗

@@ -20,4 +20,7 @@ openssl = "0.9.7"
 base64 = "0.4.0"
 minidom = "0.1.0"
 jid = "0.1.0"
-sasl = "0.1.0"
+sasl = "0.3.0"
+
+[features]
+insecure = []

src/client.rs 🔗

@@ -5,8 +5,14 @@ use ns;
 use plugin::{Plugin, PluginProxyBinding};
 use event::AbstractEvent;
 use connection::{Connection, C2S};
-use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
+use sasl::{ Mechanism as SaslMechanism
+          , Credentials as SaslCredentials
+          , Secret as SaslSecret
+          , ChannelBinding
+          };
 use sasl::mechanisms::{Plain, Scram, Sha1, Sha256};
+use components::sasl_error::SaslError;
+use util::FromElement;
 
 use base64;
 
@@ -27,7 +33,7 @@ pub struct StreamFeatures {
 /// A builder for `Client`s.
 pub struct ClientBuilder {
     jid: Jid,
-    credentials: Option<SaslCredentials>,
+    credentials: SaslCredentials,
     host: Option<String>,
     port: u16,
 }
@@ -37,7 +43,7 @@ impl ClientBuilder {
     pub fn new(jid: Jid) -> ClientBuilder {
         ClientBuilder {
             jid: jid,
-            credentials: None,
+            credentials: SaslCredentials::default(),
             host: None,
             port: 5222,
         }
@@ -57,11 +63,11 @@ impl ClientBuilder {
 
     /// Sets the password to use.
     pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
-        self.credentials = Some(SaslCredentials {
-            username: self.jid.node.clone().expect("JID has no node"),
+        self.credentials = SaslCredentials {
+            username: Some(self.jid.node.clone().expect("JID has no node")),
             secret: SaslSecret::Password(password.into()),
-            channel_binding: None,
-        });
+            channel_binding: ChannelBinding::None,
+        };
         self
     }
 
@@ -72,7 +78,7 @@ impl ClientBuilder {
         C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
         let (sender_out, sender_in) = channel();
         let (dispatcher_out, dispatcher_in) = channel();
-        let mut credentials = self.credentials.expect("can't connect without credentials");
+        let mut credentials = self.credentials;
         credentials.channel_binding = transport.channel_bind();
         let mut client = Client {
             jid: self.jid,
@@ -147,15 +153,23 @@ impl Client {
         Ok(())
     }
 
-    fn connect(&mut self, credentials: SaslCredentials) -> Result<(), Error> {
+    fn connect(&mut self, mut credentials: SaslCredentials) -> Result<(), Error> {
         let features = self.wait_for_features()?;
         let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
         fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
         // TODO: better way for selecting these, enabling anonymous auth
-        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256") {
+        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256-PLUS") {
+            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
+        }
+        else if ms.contains("SCRAM-SHA-1-PLUS") {
+            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
+        }
+        else if ms.contains("SCRAM-SHA-256") {
+            credentials.channel_binding = ChannelBinding::Unsupported;
             Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
         }
         else if ms.contains("SCRAM-SHA-1") {
+            credentials.channel_binding = ChannelBinding::Unsupported;
             Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
         }
         else if ms.contains("PLAIN") {
@@ -207,9 +221,8 @@ impl Client {
                 return Ok(());
             }
             else if n.is("failure", ns::SASL) {
-                let msg = n.text();
-                let inner = if msg == "" { None } else { Some(msg) };
-                return Err(Error::SaslError(inner));
+                let inner = SaslError::from_element(&n).map_err(|_| Error::SaslError(None))?;
+                return Err(Error::XmppSaslError(inner));
             }
         }
     }

src/components/sasl_error.rs 🔗

@@ -0,0 +1,75 @@
+use ns;
+use minidom::Element;
+use util::FromElement;
+
+#[derive(Clone, Debug)]
+pub enum Condition {
+    Aborted,
+    AccountDisabled(Option<String>),
+    CredentialsExpired,
+    EncryptionRequired,
+    IncorrectEncoding,
+    InvalidAuthzid,
+    InvalidMechanism,
+    MalformedRequest,
+    MechanismTooWeak,
+    NotAuthorized,
+    TemporaryAuthFailure,
+    Unknown,
+}
+
+#[derive(Clone, Debug)]
+pub struct SaslError {
+    condition: Condition,
+    text: Option<String>,
+}
+
+impl FromElement for SaslError {
+    type Err = ();
+
+    fn from_element(element: &Element) -> Result<SaslError, ()> {
+        if !element.is("failure", ns::SASL) {
+            return Err(());
+        }
+        let mut err = SaslError {
+            condition: Condition::Unknown,
+            text: None,
+        };
+        if let Some(text) = element.get_child("text", ns::SASL) {
+            let desc = text.text();
+            err.text = Some(desc);
+        }
+        if element.has_child("aborted", ns::SASL) {
+            err.condition = Condition::Aborted;
+        }
+        else if let Some(account_disabled) = element.get_child("account-disabled", ns::SASL) {
+            let text = account_disabled.text();
+            err.condition = Condition::AccountDisabled(if text == "" { None } else { Some(text) });
+        }
+        else if element.has_child("credentials-expired", ns::SASL) {
+            err.condition = Condition::CredentialsExpired;
+        }
+        else if element.has_child("encryption-required", ns::SASL) {
+            err.condition = Condition::EncryptionRequired;
+        }
+        else if element.has_child("incorrect-encoding", ns::SASL) {
+            err.condition = Condition::IncorrectEncoding;
+        }
+        else if element.has_child("invalid-authzid", ns::SASL) {
+            err.condition = Condition::InvalidAuthzid;
+        }
+        else if element.has_child("malformed-request", ns::SASL) {
+            err.condition = Condition::MalformedRequest;
+        }
+        else if element.has_child("mechanism-too-weak", ns::SASL) {
+            err.condition = Condition::MechanismTooWeak;
+        }
+        else if element.has_child("not-authorized", ns::SASL) {
+            err.condition = Condition::NotAuthorized;
+        }
+        else if element.has_child("temporary-auth-failure", ns::SASL) {
+            err.condition = Condition::TemporaryAuthFailure;
+        }
+        Ok(err)
+    }
+}

src/components/stanza_error.rs 🔗

@@ -0,0 +1,172 @@
+use ns;
+use minidom::Element;
+use util::{FromElement, FromParentElement};
+use std::str::FromStr;
+
+#[derive(Copy, Clone, Debug)]
+pub enum ErrorType {
+    Auth,
+    Cancel,
+    Continue,
+    Modify,
+    Wait,
+}
+
+impl FromStr for ErrorType {
+    type Err = ();
+
+    fn from_str(s: &str) -> Result<ErrorType, ()> {
+        Ok(match s {
+            "auth" => ErrorType::Auth,
+            "cancel" => ErrorType::Cancel,
+            "continue" => ErrorType::Continue,
+            "modify" => ErrorType::Modify,
+            "wait" => ErrorType::Wait,
+            _ => { return Err(()); },
+        })
+    }
+}
+
+#[derive(Clone, Debug)]
+pub enum Condition {
+    BadRequest,
+    Conflict,
+    FeatureNotImplemented,
+    Forbidden,
+    Gone(Option<String>),
+    InternalServerError,
+    ItemNotFound,
+    JidMalformed,
+    NotAcceptable,
+    NotAllowed,
+    NotAuthorized,
+    PolicyViolation,
+    RecipientUnavailable,
+    Redirect(Option<String>),
+    RegistrationRequired,
+    RemoteServerNotFound,
+    RemoteServerTimeout,
+    ResourceConstraint,
+    ServiceUnavailable,
+    SubscriptionRequired,
+    UndefinedCondition,
+    UnexpectedRequest,
+}
+
+impl FromParentElement for Condition {
+    type Err = ();
+
+    fn from_parent_element(elem: &Element) -> Result<Condition, ()> {
+        if elem.has_child("bad-request", ns::STANZAS) {
+            Ok(Condition::BadRequest)
+        }
+        else if elem.has_child("conflict", ns::STANZAS) {
+            Ok(Condition::Conflict)
+        }
+        else if elem.has_child("feature-not-implemented", ns::STANZAS) {
+            Ok(Condition::FeatureNotImplemented)
+        }
+        else if elem.has_child("forbidden", ns::STANZAS) {
+            Ok(Condition::Forbidden)
+        }
+        else if let Some(alt) = elem.get_child("gone", ns::STANZAS) {
+            let text = alt.text();
+            let inner = if text == "" { None } else { Some(text) };
+            Ok(Condition::Gone(inner))
+        }
+        else if elem.has_child("internal-server-error", ns::STANZAS) {
+            Ok(Condition::InternalServerError)
+        }
+        else if elem.has_child("item-not-found", ns::STANZAS) {
+            Ok(Condition::ItemNotFound)
+        }
+        else if elem.has_child("jid-malformed", ns::STANZAS) {
+            Ok(Condition::JidMalformed)
+        }
+        else if elem.has_child("not-acceptable", ns::STANZAS) {
+            Ok(Condition::NotAcceptable)
+        }
+        else if elem.has_child("not-allowed", ns::STANZAS) {
+            Ok(Condition::NotAllowed)
+        }
+        else if elem.has_child("not-authorized", ns::STANZAS) {
+            Ok(Condition::NotAuthorized)
+        }
+        else if elem.has_child("policy-violation", ns::STANZAS) {
+            Ok(Condition::PolicyViolation)
+        }
+        else if elem.has_child("recipient-unavailable", ns::STANZAS) {
+            Ok(Condition::RecipientUnavailable)
+        }
+        else if let Some(alt) = elem.get_child("redirect", ns::STANZAS) {
+            let text = alt.text();
+            let inner = if text == "" { None } else { Some(text) };
+            Ok(Condition::Redirect(inner))
+        }
+        else if elem.has_child("registration-required", ns::STANZAS) {
+            Ok(Condition::RegistrationRequired)
+        }
+        else if elem.has_child("remote-server-not-found", ns::STANZAS) {
+            Ok(Condition::RemoteServerNotFound)
+        }
+        else if elem.has_child("remote-server-timeout", ns::STANZAS) {
+            Ok(Condition::RemoteServerTimeout)
+        }
+        else if elem.has_child("resource-constraint", ns::STANZAS) {
+            Ok(Condition::ResourceConstraint)
+        }
+        else if elem.has_child("service-unavailable", ns::STANZAS) {
+            Ok(Condition::ServiceUnavailable)
+        }
+        else if elem.has_child("subscription-required", ns::STANZAS) {
+            Ok(Condition::SubscriptionRequired)
+        }
+        else if elem.has_child("undefined-condition", ns::STANZAS) {
+            Ok(Condition::UndefinedCondition)
+        }
+        else if elem.has_child("unexpected-request", ns::STANZAS) {
+            Ok(Condition::UnexpectedRequest)
+        }
+        else {
+            Err(())
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct StanzaError {
+    error_type: ErrorType,
+    text: Option<String>,
+    condition: Condition,
+}
+
+impl StanzaError {
+    pub fn new(error_type: ErrorType, text: Option<String>, condition: Condition) -> StanzaError {
+        StanzaError {
+            error_type: error_type,
+            text: text,
+            condition: condition,
+        }
+    }
+}
+
+impl FromElement for StanzaError {
+    type Err = ();
+
+    fn from_element(elem: &Element) -> Result<StanzaError, ()> {
+        if elem.is("error", ns::STANZAS) {
+            let error_type = elem.attr("type").ok_or(())?;
+            let err: ErrorType = error_type.parse().map_err(|_| ())?;
+            let condition: Condition = Condition::from_parent_element(elem)?;
+            let text = elem.get_child("text", ns::STANZAS).map(|c| c.text());
+            Ok(StanzaError {
+                error_type: err,
+                text: text,
+                condition: condition,
+            })
+        }
+        else {
+            Err(())
+        }
+    }
+}

src/error.rs 🔗

@@ -14,6 +14,8 @@ use minidom::Error as MinidomError;
 
 use base64::Base64Error;
 
+use components::sasl_error::SaslError;
+
 /// An error which wraps a bunch of errors from different crates and the stdlib.
 #[derive(Debug)]
 pub enum Error {
@@ -25,6 +27,7 @@ pub enum Error {
     MinidomError(MinidomError),
     Base64Error(Base64Error),
     SaslError(Option<String>),
+    XmppSaslError(SaslError),
     StreamError,
     EndOfDocument,
 }

src/lib.rs 🔗

@@ -13,5 +13,7 @@ pub mod plugin;
 pub mod event;
 pub mod plugins;
 pub mod connection;
+pub mod util;
+pub mod components;
 
 mod locked_io;

src/ns.rs 🔗

@@ -5,3 +5,4 @@ pub const STREAM: &'static str = "http://etherx.jabber.org/streams";
 pub const TLS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
 pub const SASL: &'static str = "urn:ietf:params:xml:ns:xmpp-sasl";
 pub const BIND: &'static str = "urn:ietf:params:xml:ns:xmpp-bind";
+pub const STANZAS: &'static str = "urn:ietf:params:xml:ns:xmpp-stanzas";

src/util.rs 🔗

@@ -0,0 +1,19 @@
+use minidom::Element;
+
+pub trait FromElement where Self: Sized {
+    type Err;
+
+    fn from_element(elem: &Element) -> Result<Self, Self::Err>;
+}
+
+pub trait FromParentElement where Self: Sized {
+    type Err;
+
+    fn from_parent_element(elem: &Element) -> Result<Self, Self::Err>;
+}
+
+pub trait ToElement where Self: Sized {
+    type Err;
+
+    fn to_element(&self) -> Result<Element, Self::Err>;
+}