refactor ScramMechanism to support PLUS

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/Config.java                         |  2 
src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java          |  5 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java      | 13 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java     | 46 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramPlusMechanism.java | 22 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1Plus.java      | 36 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java            | 19 
7 files changed, 126 insertions(+), 17 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/Config.java 🔗

@@ -57,6 +57,8 @@ public final class Config {
     public static final long CONTACT_SYNC_RETRY_INTERVAL = 1000L * 60 * 5;
 
 
+    public static final boolean SASL_2_ENABLED = false;
+
     //Notification settings
     public static final boolean HIDE_MESSAGE_TEXT_IN_NOTIFICATION = false;
     public static final boolean ALWAYS_NOTIFY_BY_DEFAULT = false;

src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java 🔗

@@ -6,6 +6,8 @@ import java.nio.charset.Charset;
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 
+import javax.net.ssl.SSLSocket;
+
 import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.utils.CryptoHelper;
 
@@ -29,7 +31,8 @@ public class DigestMd5 extends SaslMechanism {
     }
 
     @Override
-    public String getResponse(final String challenge) throws AuthenticationException {
+    public String getResponse(final String challenge, final SSLSocket sslSocket)
+            throws AuthenticationException {
         switch (state) {
             case INITIAL:
                 state = State.RESPONSE_SENT;

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -4,6 +4,8 @@ import com.google.common.base.Strings;
 
 import java.util.Collection;
 
+import javax.net.ssl.SSLSocket;
+
 import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.xml.Element;
 import eu.siacs.conversations.xml.Namespace;
@@ -31,7 +33,8 @@ public abstract class SaslMechanism {
         return "";
     }
 
-    public String getResponse(final String challenge) throws AuthenticationException {
+    public String getResponse(final String challenge, final SSLSocket sslSocket)
+            throws AuthenticationException {
         return "";
     }
 
@@ -112,4 +115,12 @@ public abstract class SaslMechanism {
             }
         }
     }
+
+    public static String namespace(final Version version) {
+        if (version == Version.SASL) {
+            return Namespace.SASL;
+        } else {
+            return Namespace.SASL_2;
+        }
+    }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java 🔗

@@ -2,6 +2,7 @@ package eu.siacs.conversations.crypto.sasl;
 
 import android.util.Base64;
 
+import com.google.common.base.CaseFormat;
 import com.google.common.base.Objects;
 import com.google.common.cache.Cache;
 import com.google.common.cache.CacheBuilder;
@@ -14,18 +15,19 @@ import java.nio.charset.Charset;
 import java.security.InvalidKeyException;
 import java.util.concurrent.ExecutionException;
 
+import javax.net.ssl.SSLSocket;
+
 import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.utils.CryptoHelper;
 
 abstract class ScramMechanism extends SaslMechanism {
-    // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to
-    // indicate support and/or usage.
-    private static final String GS2_HEADER = "n,,";
+
     private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
     private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
     private static final Cache<CacheKey, KeyPair> CACHE =
             CacheBuilder.newBuilder().maximumSize(10).build();
     protected final ChannelBinding channelBinding;
+    private final String gs2Header;
     private final String clientNonce;
     protected State state = State.INITIAL;
     private String clientFirstMessageBare;
@@ -34,6 +36,16 @@ abstract class ScramMechanism extends SaslMechanism {
     ScramMechanism(final Account account, final ChannelBinding channelBinding) {
         super(account);
         this.channelBinding = channelBinding;
+        if (channelBinding == ChannelBinding.NONE) {
+            this.gs2Header = "n,,";
+        } else {
+            this.gs2Header =
+                    String.format(
+                            "p=%s,,",
+                            CaseFormat.UPPER_UNDERSCORE
+                                    .converterTo(CaseFormat.LOWER_HYPHEN)
+                                    .convert(channelBinding.toString()));
+        }
         // This nonce should be different for each authentication attempt.
         this.clientNonce = CryptoHelper.random(100);
         clientFirstMessageBare = "";
@@ -69,7 +81,7 @@ abstract class ScramMechanism extends SaslMechanism {
         return out;
     }
 
-    public byte[] digest(byte[] bytes) {
+    public byte[] digest(final byte[] bytes) {
         final Digest digest = getDigest();
         digest.reset();
         digest.update(bytes, 0, bytes.length);
@@ -107,12 +119,13 @@ abstract class ScramMechanism extends SaslMechanism {
             state = State.AUTH_TEXT_SENT;
         }
         return Base64.encodeToString(
-                (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
+                (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
                 Base64.NO_WRAP);
     }
 
     @Override
-    public String getResponse(final String challenge) throws AuthenticationException {
+    public String getResponse(final String challenge, final SSLSocket socket)
+            throws AuthenticationException {
         switch (state) {
             case AUTH_TEXT_SENT:
                 if (challenge == null) {
@@ -169,11 +182,17 @@ abstract class ScramMechanism extends SaslMechanism {
                     throw new AuthenticationException("Server sent empty salt");
                 }
 
+                final byte[] channelBindingData = getChannelBindingData(socket);
+
+                final int gs2Len = this.gs2Header.getBytes().length;
+                final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
+                System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
+                System.arraycopy(
+                        channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
+
                 final String clientFinalMessageWithoutProof =
-                        "c="
-                                + Base64.encodeToString(GS2_HEADER.getBytes(), Base64.NO_WRAP)
-                                + ",r="
-                                + nonce;
+                        "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
+
                 final byte[] authMessage =
                         (clientFirstMessageBare
                                         + ','
@@ -239,6 +258,13 @@ abstract class ScramMechanism extends SaslMechanism {
         }
     }
 
+    protected byte[] getChannelBindingData(final SSLSocket sslSocket) throws AuthenticationException {
+        if (this.channelBinding == ChannelBinding.NONE) {
+            return new byte[0];
+        }
+        throw new AssertionError("getChannelBindingData needs to be overwritten");
+    }
+
     private static class CacheKey {
         final String algorithm;
         final String password;

src/main/java/eu/siacs/conversations/crypto/sasl/ScramPlusMechanism.java 🔗

@@ -0,0 +1,22 @@
+package eu.siacs.conversations.crypto.sasl;
+
+import javax.net.ssl.SSLSocket;
+
+import eu.siacs.conversations.entities.Account;
+
+abstract class ScramPlusMechanism extends ScramMechanism {
+    ScramPlusMechanism(Account account, ChannelBinding channelBinding) {
+        super(account, channelBinding);
+    }
+
+    @Override
+    protected byte[] getChannelBindingData(final SSLSocket sslSocket) throws AuthenticationException {
+        if (this.channelBinding == ChannelBinding.NONE) {
+            throw new AuthenticationException(String.format("%s is not a valid channel binding", ChannelBinding.NONE));
+        }
+        if (sslSocket == null) {
+            throw new AuthenticationException("Channel binding attempt on non secure socket");
+        }
+        throw new AssertionError("not yet implemented");
+    }
+}

src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1Plus.java 🔗

@@ -0,0 +1,36 @@
+package eu.siacs.conversations.crypto.sasl;
+
+import org.bouncycastle.crypto.Digest;
+import org.bouncycastle.crypto.digests.SHA1Digest;
+import org.bouncycastle.crypto.macs.HMac;
+
+import eu.siacs.conversations.entities.Account;
+
+public class ScramSha1Plus extends ScramPlusMechanism {
+
+    public static final String MECHANISM = "SCRAM-SHA-1-PLUS";
+
+    public ScramSha1Plus(final Account account, final ChannelBinding channelBinding) {
+        super(account, channelBinding);
+    }
+
+    @Override
+    protected HMac getHMAC() {
+        return new HMac(new SHA1Digest());
+    }
+
+    @Override
+    protected Digest getDigest() {
+        return new SHA1Digest();
+    }
+
+    @Override
+    public int getPriority() {
+        return 35; //higher than SCRAM-SHA512 (30)
+    }
+
+    @Override
+    public String getMechanism() {
+        return MECHANISM;
+    }
+}

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -607,7 +607,7 @@ public class XmppConnection implements Runnable {
                     throw new AssertionError("Missing implementation for " + version);
                 }
                 try {
-                    response.setContent(saslMechanism.getResponse(challenge.getContent()));
+                    response.setContent(saslMechanism.getResponse(challenge.getContent(), sslSocketOrNull(socket)));
                 } catch (final SaslMechanism.AuthenticationException e) {
                     // TODO: Send auth abort tag.
                     Log.e(Config.LOGTAG, e.toString());
@@ -707,7 +707,7 @@ public class XmppConnection implements Runnable {
             throw new AssertionError("Missing implementation for " + version);
         }
         try {
-            saslMechanism.getResponse(challenge);
+            saslMechanism.getResponse(challenge, sslSocketOrNull(socket));
         } catch (final SaslMechanism.AuthenticationException e) {
             Log.e(Config.LOGTAG, String.valueOf(e));
             throw new StateChangingException(Account.State.UNAUTHORIZED);
@@ -798,6 +798,14 @@ public class XmppConnection implements Runnable {
         }
     }
 
+    private static SSLSocket sslSocketOrNull(final Socket socket) {
+        if (socket instanceof SSLSocket) {
+            return (SSLSocket) socket;
+        } else {
+            return null;
+        }
+    }
+
     private void processEnabled(final Element enabled) {
         final String streamId;
         if (enabled.getAttributeAsBoolean("resume")) {
@@ -1170,7 +1178,8 @@ public class XmppConnection implements Runnable {
         } else if (!this.streamFeatures.hasChild("register", Namespace.REGISTER_STREAM_FEATURE)
                 && account.isOptionSet(Account.OPTION_REGISTER)) {
             throw new StateChangingException(Account.State.REGISTRATION_NOT_SUPPORTED);
-        } else if (this.streamFeatures.hasChild("mechanisms", Namespace.SASL_2)
+        } else if (Config.SASL_2_ENABLED
+                && this.streamFeatures.hasChild("mechanisms", Namespace.SASL_2)
                 && shouldAuthenticate
                 && isSecure) {
             authenticate(SaslMechanism.Version.SASL_2);
@@ -1213,9 +1222,8 @@ public class XmppConnection implements Runnable {
     }
 
     private void authenticate(final SaslMechanism.Version version) throws IOException {
-        Log.d(Config.LOGTAG, "stream features: " + this.streamFeatures);
         final Element element =
-                this.streamFeatures.findChild("mechanisms"); // TODO get from correct NS
+                this.streamFeatures.findChild("mechanisms", SaslMechanism.namespace(version));
         final Collection<String> mechanisms =
                 Collections2.transform(
                         Collections2.filter(
@@ -1234,6 +1242,7 @@ public class XmppConnection implements Runnable {
                                         c -> c != null && "channel-binding".equals(c.getName())),
                                 c -> c == null ? null : ChannelBinding.of(c.getAttribute("type"))),
                         Predicates.notNull());
+        Log.d(Config.LOGTAG,"mechanisms: "+mechanisms);
         Log.d(Config.LOGTAG, "channel bindings: " + channelBindings);
         final SaslMechanism.Factory factory = new SaslMechanism.Factory(account);
         this.saslMechanism = factory.of(mechanisms, channelBindings);