support logging in via SASL 2

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java |  17 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java       | 160 
2 files changed, 121 insertions(+), 56 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -1,8 +1,12 @@
 package eu.siacs.conversations.crypto.sasl;
 
+import com.google.common.base.Strings;
+
 import java.security.SecureRandom;
 
 import eu.siacs.conversations.entities.Account;
+import eu.siacs.conversations.xml.Element;
+import eu.siacs.conversations.xml.Namespace;
 import eu.siacs.conversations.xml.TagWriter;
 
 public abstract class SaslMechanism {
@@ -68,6 +72,17 @@ public abstract class SaslMechanism {
     }
 
     public enum Version {
-        SASL, SASL_2
+        SASL, SASL_2;
+
+        public static Version of(final Element element) {
+            switch ( Strings.nullToEmpty(element.getNamespace())) {
+                case Namespace.SASL:
+                    return SASL;
+                case Namespace.SASL_2:
+                    return SASL_2;
+                default:
+                    throw new IllegalArgumentException("Unrecognized SASL namespace");
+            }
+        }
     }
 }

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -469,63 +469,102 @@ public class XmppConnection implements Runnable {
             } else if (nextTag.isStart("proceed")) {
                 switchOverToTls();
             } else if (nextTag.isStart("success")) {
-                final String challenge = tagReader.readElement(nextTag).getContent();
+                final Element success = tagReader.readElement(nextTag);
+                final SaslMechanism.Version version;
+                try {
+                    version = SaslMechanism.Version.of(success);
+                } catch (final IllegalArgumentException e) {
+                    throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+                }
+                final String challenge;
+                if (version == SaslMechanism.Version.SASL) {
+                    challenge = success.getContent();
+                } else if (version == SaslMechanism.Version.SASL_2) {
+                    challenge = success.findChildContent("additional-data");
+                } else {
+                    throw new AssertionError("Missing implementation for " + version);
+                }
                 try {
                     saslMechanism.getResponse(challenge);
                 } catch (final SaslMechanism.AuthenticationException e) {
                     Log.e(Config.LOGTAG, String.valueOf(e));
                     throw new StateChangingException(Account.State.UNAUTHORIZED);
                 }
-                Log.d(Config.LOGTAG, account.getJid().asBareJid().toString() + ": logged in");
-                account.setKey(Account.PINNED_MECHANISM_KEY,
-                        String.valueOf(saslMechanism.getPriority()));
-                tagReader.reset();
-                sendStartStream();
-                final Tag tag = tagReader.readTag();
-                if (tag != null && tag.isStart("stream")) {
-                    processStream();
-                } else {
-                    throw new StateChangingException(Account.State.STREAM_OPENING_ERROR);
+                Log.d(
+                        Config.LOGTAG,
+                        account.getJid().asBareJid().toString()
+                                + ": logged in (using "
+                                + version
+                                + ")");
+                account.setKey(
+                        Account.PINNED_MECHANISM_KEY, String.valueOf(saslMechanism.getPriority()));
+                if (version == SaslMechanism.Version.SASL) {
+                    tagReader.reset();
+                    sendStartStream();
+                    final Tag tag = tagReader.readTag();
+                    if (tag != null && tag.isStart("stream")) {
+                        processStream();
+                    } else {
+                        throw new StateChangingException(Account.State.STREAM_OPENING_ERROR);
+                    }
+                    break;
                 }
-                break;
             } else if (nextTag.isStart("failure")) {
                 final Element failure = tagReader.readElement(nextTag);
-                if (Namespace.SASL.equals(failure.getNamespace())) {
-                    if (failure.hasChild("temporary-auth-failure")) {
-                        throw new StateChangingException(Account.State.TEMPORARY_AUTH_FAILURE);
-                    } else if (failure.hasChild("account-disabled")) {
-                        final String text = failure.findChildContent("text");
-                        if ( Strings.isNullOrEmpty(text)) {
+                if (Namespace.TLS.equals(failure.getNamespace())) {
+                    throw new StateChangingException(Account.State.TLS_ERROR);
+                }
+                final SaslMechanism.Version version;
+                try {
+                    version = SaslMechanism.Version.of(failure);
+                } catch (final IllegalArgumentException e) {
+                    throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+                }
+                Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": login failure " + version);
+                if (failure.hasChild("temporary-auth-failure")) {
+                    throw new StateChangingException(Account.State.TEMPORARY_AUTH_FAILURE);
+                } else if (failure.hasChild("account-disabled")) {
+                    final String text = failure.findChildContent("text");
+                    if (Strings.isNullOrEmpty(text)) {
+                        throw new StateChangingException(Account.State.UNAUTHORIZED);
+                    }
+                    final Matcher matcher = Patterns.AUTOLINK_WEB_URL.matcher(text);
+                    if (matcher.find()) {
+                        final HttpUrl url;
+                        try {
+                            url = HttpUrl.get(text.substring(matcher.start(), matcher.end()));
+                        } catch (final IllegalArgumentException e) {
                             throw new StateChangingException(Account.State.UNAUTHORIZED);
                         }
-                        final Matcher matcher = Patterns.AUTOLINK_WEB_URL.matcher(text);
-                        if (matcher.find()) {
-                            final HttpUrl url;
-                            try {
-                                url = HttpUrl.get(text.substring(matcher.start(), matcher.end()));
-                            } catch (final IllegalArgumentException e) {
-                                throw new StateChangingException(Account.State.UNAUTHORIZED);
-                            }
-                            if (url.isHttps()) {
-                                this.redirectionUrl = url;
-                                throw new StateChangingException(Account.State.PAYMENT_REQUIRED);
-                            }
+                        if (url.isHttps()) {
+                            this.redirectionUrl = url;
+                            throw new StateChangingException(Account.State.PAYMENT_REQUIRED);
                         }
                     }
-                    throw new StateChangingException(Account.State.UNAUTHORIZED);
-                } else if (Namespace.TLS.equals(failure.getNamespace())) {
-                    throw new StateChangingException(Account.State.TLS_ERROR);
-                } else {
-                    throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
                 }
+                throw new StateChangingException(Account.State.UNAUTHORIZED);
             } else if (nextTag.isStart("challenge")) {
-                final String challenge = tagReader.readElement(nextTag).getContent();
-                final Element response = new Element("response", Namespace.SASL);
+                final Element challenge = tagReader.readElement(nextTag);
+                final SaslMechanism.Version version;
+                try {
+                    version = SaslMechanism.Version.of(challenge);
+                } catch (final IllegalArgumentException e) {
+                    throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+                }
+                final Element response;
+                if (version == SaslMechanism.Version.SASL) {
+                    response = new Element("response", Namespace.SASL);
+                } else if (version == SaslMechanism.Version.SASL_2) {
+                    response = new Element("response", Namespace.SASL_2);
+                } else {
+                    throw new AssertionError("Missing implementation for " + version);
+                }
                 try {
-                    response.setContent(saslMechanism.getResponse(challenge));
+                    response.setContent(saslMechanism.getResponse(challenge.getContent()));
                 } catch (final SaslMechanism.AuthenticationException e) {
                     // TODO: Send auth abort tag.
                     Log.e(Config.LOGTAG, e.toString());
+                    throw new StateChangingException(Account.State.UNAUTHORIZED);
                 }
                 tagWriter.writeElement(response);
             } else if (nextTag.isStart("enabled")) {
@@ -848,7 +887,6 @@ public class XmppConnection implements Runnable {
 
     private void processStreamFeatures(final Tag currentTag) throws IOException {
         this.streamFeatures = tagReader.readElement(currentTag);
-        Log.d(Config.LOGTAG, this.streamFeatures.toString());
         final boolean isSecure =
                 features.encryptionEnabled || Config.ALLOW_NON_TLS_CONNECTIONS || account.isOnion();
         final boolean needsBinding = !isBound && !account.isOptionSet(Account.OPTION_REGISTER);
@@ -907,7 +945,6 @@ public class XmppConnection implements Runnable {
 
     private void authenticate(final SaslMechanism.Version version) throws IOException {
         final List<String> mechanisms = extractMechanisms(streamFeatures.findChild("mechanisms"));
-        final Element auth = new Element("auth", Namespace.SASL);
         if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
             saslMechanism = new External(tagWriter, account, mXmppConnectionService.getRNG());
         } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
@@ -923,25 +960,38 @@ public class XmppConnection implements Runnable {
         } else if (mechanisms.contains(Anonymous.MECHANISM)) {
             saslMechanism = new Anonymous(tagWriter, account, mXmppConnectionService.getRNG());
         }
-        if (saslMechanism != null) {
-            final int pinnedMechanism = account.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1);
-            if (pinnedMechanism > saslMechanism.getPriority()) {
-                Log.e(Config.LOGTAG, "Auth failed. Authentication mechanism " + saslMechanism.getMechanism() +
-                        " has lower priority (" + saslMechanism.getPriority() +
-                        ") than pinned priority (" + pinnedMechanism +
-                        "). Possible downgrade attack?");
-                throw new StateChangingException(Account.State.DOWNGRADE_ATTACK);
+        if (saslMechanism == null) {
+            Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": unable to find supported SASL mechanism in " + mechanisms);
+            throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+        }
+        final int pinnedMechanism = account.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1);
+        if (pinnedMechanism > saslMechanism.getPriority()) {
+            Log.e(Config.LOGTAG, "Auth failed. Authentication mechanism " + saslMechanism.getMechanism() +
+                    " has lower priority (" + saslMechanism.getPriority() +
+                    ") than pinned priority (" + pinnedMechanism +
+                    "). Possible downgrade attack?");
+            throw new StateChangingException(Account.State.DOWNGRADE_ATTACK);
+        }
+        final String firstMessage = saslMechanism.getClientFirstMessage();
+        final Element authenticate;
+        if (version == SaslMechanism.Version.SASL) {
+            authenticate = new Element("auth", Namespace.SASL);
+            if (!Strings.isNullOrEmpty(firstMessage)) {
+                authenticate.setContent(firstMessage);
             }
-            Log.d(Config.LOGTAG, account.getJid().toString() + ": Authenticating with " + saslMechanism.getMechanism());
-            auth.setAttribute("mechanism", saslMechanism.getMechanism());
-            if (!saslMechanism.getClientFirstMessage().isEmpty()) {
-                auth.setContent(saslMechanism.getClientFirstMessage());
+        } else if (version == SaslMechanism.Version.SASL_2) {
+            authenticate = new Element("authenticate", Namespace.SASL_2);
+            if (!Strings.isNullOrEmpty(firstMessage)) {
+                authenticate.addChild("initial-response").setContent(firstMessage);
             }
-            tagWriter.writeElement(auth);
+            // TODO place to add extensions
         } else {
-            Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": unable to find supported SASL mechanism in " + mechanisms);
-            throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+            throw new AssertionError("Missing implementation for " + version);
         }
+
+        Log.d(Config.LOGTAG, account.getJid().toString() + ": Authenticating with "+version+ "/" + saslMechanism.getMechanism());
+        authenticate.setAttribute("mechanism", saslMechanism.getMechanism());
+        tagWriter.writeElement(authenticate);
     }
 
     private List<String> extractMechanisms(final Element stream) {