add support for HashedToken channel binding

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/crypto/sasl/Anonymous.java               |  4 
src/main/java/eu/siacs/conversations/crypto/sasl/ChannelBindingMechanism.java | 94 
src/main/java/eu/siacs/conversations/crypto/sasl/External.java                |  4 
src/main/java/eu/siacs/conversations/crypto/sasl/HashedToken.java             | 25 
src/main/java/eu/siacs/conversations/crypto/sasl/Plain.java                   |  4 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java           | 15 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java          |  2 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramPlusMechanism.java      | 89 
src/main/java/eu/siacs/conversations/entities/Account.java                    |  7 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java                 | 75 
10 files changed, 187 insertions(+), 132 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/sasl/Anonymous.java 🔗

@@ -1,5 +1,7 @@
 package eu.siacs.conversations.crypto.sasl;
 
+import javax.net.ssl.SSLSocket;
+
 import eu.siacs.conversations.entities.Account;
 
 public class Anonymous extends SaslMechanism {
@@ -21,7 +23,7 @@ public class Anonymous extends SaslMechanism {
     }
 
     @Override
-    public String getClientFirstMessage() {
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
         return "";
     }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/ChannelBindingMechanism.java 🔗

@@ -1,6 +1,100 @@
 package eu.siacs.conversations.crypto.sasl;
 
+import org.bouncycastle.jcajce.provider.digest.SHA256;
+import org.conscrypt.Conscrypt;
+
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateEncodingException;
+import java.security.cert.X509Certificate;
+
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSession;
+import javax.net.ssl.SSLSocket;
+
 public interface ChannelBindingMechanism {
 
+    String EXPORTER_LABEL = "EXPORTER-Channel-Binding";
+
     ChannelBinding getChannelBinding();
+
+    static byte[] getChannelBindingData(final SSLSocket sslSocket, final ChannelBinding channelBinding)
+            throws SaslMechanism.AuthenticationException {
+        if (sslSocket == null) {
+            throw new SaslMechanism.AuthenticationException("Channel binding attempt on non secure socket");
+        }
+        if (channelBinding == ChannelBinding.TLS_EXPORTER) {
+            final byte[] keyingMaterial;
+            try {
+                keyingMaterial =
+                        Conscrypt.exportKeyingMaterial(sslSocket, EXPORTER_LABEL, new byte[0], 32);
+            } catch (final SSLException e) {
+                throw new SaslMechanism.AuthenticationException("Could not export keying material");
+            }
+            if (keyingMaterial == null) {
+                throw new SaslMechanism.AuthenticationException(
+                        "Could not export keying material. Socket not ready");
+            }
+            return keyingMaterial;
+        } else if (channelBinding == ChannelBinding.TLS_UNIQUE) {
+            final byte[] unique = Conscrypt.getTlsUnique(sslSocket);
+            if (unique == null) {
+                throw new SaslMechanism.AuthenticationException(
+                        "Could not retrieve tls unique. Socket not ready");
+            }
+            return unique;
+        } else if (channelBinding == ChannelBinding.TLS_SERVER_END_POINT) {
+            return getServerEndPointChannelBinding(sslSocket.getSession());
+        } else {
+            throw new SaslMechanism.AuthenticationException(
+                    String.format("%s is not a valid channel binding", channelBinding));
+        }
+    }
+
+    static byte[] getServerEndPointChannelBinding(final SSLSession session)
+            throws SaslMechanism.AuthenticationException {
+        final Certificate[] certificates;
+        try {
+            certificates = session.getPeerCertificates();
+        } catch (final SSLPeerUnverifiedException e) {
+            throw new SaslMechanism.AuthenticationException("Could not verify peer certificates");
+        }
+        if (certificates == null || certificates.length == 0) {
+            throw new SaslMechanism.AuthenticationException("Could not retrieve peer certificate");
+        }
+        final X509Certificate certificate;
+        if (certificates[0] instanceof X509Certificate) {
+            certificate = (X509Certificate) certificates[0];
+        } else {
+            throw new SaslMechanism.AuthenticationException("Certificate was not X509");
+        }
+        final String algorithm = certificate.getSigAlgName();
+        final int withIndex = algorithm.indexOf("with");
+        if (withIndex <= 0) {
+            throw new SaslMechanism.AuthenticationException("Unable to parse SigAlgName");
+        }
+        final String hashAlgorithm = algorithm.substring(0, withIndex);
+        final MessageDigest messageDigest;
+        // https://www.rfc-editor.org/rfc/rfc5929#section-4.1
+        if ("MD5".equalsIgnoreCase(hashAlgorithm) || "SHA1".equalsIgnoreCase(hashAlgorithm)) {
+            messageDigest = new SHA256.Digest();
+        } else {
+            try {
+                messageDigest = MessageDigest.getInstance(hashAlgorithm);
+            } catch (final NoSuchAlgorithmException e) {
+                throw new SaslMechanism.AuthenticationException(
+                        "Could not instantiate message digest for " + hashAlgorithm);
+            }
+        }
+        final byte[] encodedCertificate;
+        try {
+            encodedCertificate = certificate.getEncoded();
+        } catch (final CertificateEncodingException e) {
+            throw new SaslMechanism.AuthenticationException("Could not encode certificate");
+        }
+        messageDigest.update(encodedCertificate);
+        return messageDigest.digest();
+    }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/External.java 🔗

@@ -2,6 +2,8 @@ package eu.siacs.conversations.crypto.sasl;
 
 import android.util.Base64;
 
+import javax.net.ssl.SSLSocket;
+
 import eu.siacs.conversations.entities.Account;
 
 public class External extends SaslMechanism {
@@ -23,7 +25,7 @@ public class External extends SaslMechanism {
     }
 
     @Override
-    public String getClientFirstMessage() {
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
         return Base64.encodeToString(
                 account.getJid().asBareJid().toEscapedString().getBytes(), Base64.NO_WRAP);
     }

src/main/java/eu/siacs/conversations/crypto/sasl/HashedToken.java 🔗

@@ -1,6 +1,7 @@
 package eu.siacs.conversations.crypto.sasl;
 
 import android.util.Base64;
+import android.util.Log;
 
 import com.google.common.base.MoreObjects;
 import com.google.common.base.Strings;
@@ -18,6 +19,7 @@ import java.util.List;
 
 import javax.net.ssl.SSLSocket;
 
+import eu.siacs.conversations.Config;
 import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.utils.SSLSockets;
 
@@ -42,10 +44,10 @@ public abstract class HashedToken extends SaslMechanism implements ChannelBindin
     }
 
     @Override
-    public String getClientFirstMessage() {
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
         final String token = Strings.nullToEmpty(this.account.getFastToken());
         final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
-        final byte[] cbData = new byte[0];
+        final byte[] cbData = getChannelBindingData(sslSocket);
         final byte[] initiatorHashedToken =
                 hashing.hashBytes(Bytes.concat(INITIATOR, cbData)).asBytes();
         final byte[] firstMessage =
@@ -56,6 +58,23 @@ public abstract class HashedToken extends SaslMechanism implements ChannelBindin
         return Base64.encodeToString(firstMessage, Base64.NO_WRAP);
     }
 
+    private byte[] getChannelBindingData(final SSLSocket sslSocket) {
+        if (this.channelBinding == ChannelBinding.NONE) {
+            return new byte[0];
+        }
+        try {
+            return ChannelBindingMechanism.getChannelBindingData(sslSocket, this.channelBinding);
+        } catch (final AuthenticationException e) {
+            Log.e(
+                    Config.LOGTAG,
+                    account.getJid().asBareJid()
+                            + ": unable to retrieve channel binding data for "
+                            + getMechanism(),
+                    e);
+            return new byte[0];
+        }
+    }
+
     @Override
     public String getResponse(final String challenge, final SSLSocket socket)
             throws AuthenticationException {
@@ -67,7 +86,7 @@ public abstract class HashedToken extends SaslMechanism implements ChannelBindin
         }
         final String token = Strings.nullToEmpty(this.account.getFastToken());
         final HashFunction hashing = getHashFunction(token.getBytes(StandardCharsets.UTF_8));
-        final byte[] cbData = new byte[0];
+        final byte[] cbData = getChannelBindingData(socket);
         final byte[] expectedResponderMessage =
                 hashing.hashBytes(Bytes.concat(RESPONDER, cbData)).asBytes();
         if (Arrays.equals(responderMessage, expectedResponderMessage)) {

src/main/java/eu/siacs/conversations/crypto/sasl/Plain.java 🔗

@@ -4,6 +4,8 @@ import android.util.Base64;
 
 import java.nio.charset.Charset;
 
+import javax.net.ssl.SSLSocket;
+
 import eu.siacs.conversations.entities.Account;
 
 public class Plain extends SaslMechanism {
@@ -30,7 +32,7 @@ public class Plain extends SaslMechanism {
     }
 
     @Override
-    public String getClientFirstMessage() {
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
         return getMessage(account.getUsername(), account.getPassword());
     }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -44,7 +44,7 @@ public abstract class SaslMechanism {
 
     public abstract String getMechanism();
 
-    public String getClientFirstMessage() {
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
         return "";
     }
 
@@ -154,7 +154,12 @@ public abstract class SaslMechanism {
         public SaslMechanism of(
                 final Collection<String> mechanisms,
                 final Collection<ChannelBinding> bindings,
+                final Version version,
                 final SSLSockets.Version sslVersion) {
+            final HashedToken fastMechanism = account.getFastMechanism();
+            if (version == Version.SASL_2 && fastMechanism != null) {
+                return fastMechanism;
+            }
             final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
             return of(mechanisms, channelBinding);
         }
@@ -180,4 +185,12 @@ public abstract class SaslMechanism {
             return mechanism;
         }
     }
+
+    public static boolean hashedToken(final SaslMechanism saslMechanism) {
+        return saslMechanism instanceof HashedToken;
+    }
+
+    public static boolean pin(final SaslMechanism saslMechanism) {
+        return !hashedToken(saslMechanism);
+    }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java 🔗

@@ -112,7 +112,7 @@ abstract class ScramMechanism extends SaslMechanism {
     }
 
     @Override
-    public String getClientFirstMessage() {
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
         if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
             clientFirstMessageBare =
                     "n="

src/main/java/eu/siacs/conversations/crypto/sasl/ScramPlusMechanism.java 🔗

@@ -1,25 +1,11 @@
 package eu.siacs.conversations.crypto.sasl;
 
-import org.bouncycastle.jcajce.provider.digest.SHA256;
-import org.conscrypt.Conscrypt;
-
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
-import java.security.cert.Certificate;
-import java.security.cert.CertificateEncodingException;
-import java.security.cert.X509Certificate;
-
-import javax.net.ssl.SSLException;
-import javax.net.ssl.SSLPeerUnverifiedException;
-import javax.net.ssl.SSLSession;
 import javax.net.ssl.SSLSocket;
 
 import eu.siacs.conversations.entities.Account;
 
 public abstract class ScramPlusMechanism extends ScramMechanism implements ChannelBindingMechanism {
 
-    private static final String EXPORTER_LABEL = "EXPORTER-Channel-Binding";
-
     ScramPlusMechanism(Account account, ChannelBinding channelBinding) {
         super(account, channelBinding);
     }
@@ -27,80 +13,7 @@ public abstract class ScramPlusMechanism extends ScramMechanism implements Chann
     @Override
     protected byte[] getChannelBindingData(final SSLSocket sslSocket)
             throws AuthenticationException {
-        if (sslSocket == null) {
-            throw new AuthenticationException("Channel binding attempt on non secure socket");
-        }
-        if (this.channelBinding == ChannelBinding.TLS_EXPORTER) {
-            final byte[] keyingMaterial;
-            try {
-                keyingMaterial =
-                        Conscrypt.exportKeyingMaterial(sslSocket, EXPORTER_LABEL, new byte[0], 32);
-            } catch (final SSLException e) {
-                throw new AuthenticationException("Could not export keying material");
-            }
-            if (keyingMaterial == null) {
-                throw new AuthenticationException(
-                        "Could not export keying material. Socket not ready");
-            }
-            return keyingMaterial;
-        } else if (this.channelBinding == ChannelBinding.TLS_UNIQUE) {
-            final byte[] unique = Conscrypt.getTlsUnique(sslSocket);
-            if (unique == null) {
-                throw new AuthenticationException(
-                        "Could not retrieve tls unique. Socket not ready");
-            }
-            return unique;
-        } else if (this.channelBinding == ChannelBinding.TLS_SERVER_END_POINT) {
-            return getServerEndPointChannelBinding(sslSocket.getSession());
-        } else {
-            throw new AuthenticationException(
-                    String.format("%s is not a valid channel binding", channelBinding));
-        }
-    }
-
-    private byte[] getServerEndPointChannelBinding(final SSLSession session)
-            throws AuthenticationException {
-        final Certificate[] certificates;
-        try {
-            certificates = session.getPeerCertificates();
-        } catch (final SSLPeerUnverifiedException e) {
-            throw new AuthenticationException("Could not verify peer certificates");
-        }
-        if (certificates == null || certificates.length == 0) {
-            throw new AuthenticationException("Could not retrieve peer certificate");
-        }
-        final X509Certificate certificate;
-        if (certificates[0] instanceof X509Certificate) {
-            certificate = (X509Certificate) certificates[0];
-        } else {
-            throw new AuthenticationException("Certificate was not X509");
-        }
-        final String algorithm = certificate.getSigAlgName();
-        final int withIndex = algorithm.indexOf("with");
-        if (withIndex <= 0) {
-            throw new AuthenticationException("Unable to parse SigAlgName");
-        }
-        final String hashAlgorithm = algorithm.substring(0, withIndex);
-        final MessageDigest messageDigest;
-        // https://www.rfc-editor.org/rfc/rfc5929#section-4.1
-        if ("MD5".equalsIgnoreCase(hashAlgorithm) || "SHA1".equalsIgnoreCase(hashAlgorithm)) {
-            messageDigest = new SHA256.Digest();
-        } else {
-            try {
-                messageDigest = MessageDigest.getInstance(hashAlgorithm);
-            } catch (final NoSuchAlgorithmException e) {
-                throw new AuthenticationException(
-                        "Could not instantiate message digest for " + hashAlgorithm);
-            }
-        }
-        final byte[] encodedCertificate;
-        try {
-            encodedCertificate = certificate.getEncoded();
-        } catch (final CertificateEncodingException e) {
-            throw new AuthenticationException("Could not encode certificate");
-        }
-        messageDigest.update(encodedCertificate);
-        return messageDigest.digest();
+        return ChannelBindingMechanism.getChannelBindingData(sslSocket, this.channelBinding);
     }
 
     @Override

src/main/java/eu/siacs/conversations/entities/Account.java 🔗

@@ -26,6 +26,7 @@ import eu.siacs.conversations.crypto.PgpDecryptionService;
 import eu.siacs.conversations.crypto.axolotl.AxolotlService;
 import eu.siacs.conversations.crypto.axolotl.XmppAxolotlSession;
 import eu.siacs.conversations.crypto.sasl.ChannelBinding;
+import eu.siacs.conversations.crypto.sasl.ChannelBindingMechanism;
 import eu.siacs.conversations.crypto.sasl.HashedToken;
 import eu.siacs.conversations.crypto.sasl.HashedTokenSha256;
 import eu.siacs.conversations.crypto.sasl.HashedTokenSha512;
@@ -348,9 +349,9 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
 
     public void setPinnedMechanism(final SaslMechanism mechanism) {
         this.pinnedMechanism = mechanism.getMechanism();
-        if (mechanism instanceof ScramPlusMechanism) {
+        if (mechanism instanceof ChannelBindingMechanism) {
             this.pinnedChannelBinding =
-                    ((ScramPlusMechanism) mechanism).getChannelBinding().toString();
+                    ((ChannelBindingMechanism) mechanism).getChannelBinding().toString();
         } else {
             this.pinnedChannelBinding = null;
         }
@@ -386,7 +387,7 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
         return new SaslMechanism.Factory(this).of(mechanism, channelBinding);
     }
 
-    private HashedToken getFastMechanism() {
+    public HashedToken getFastMechanism() {
         final HashedToken.Mechanism fastMechanism = HashedToken.Mechanism.ofOrNull(this.fastMechanism);
         final String token = this.fastToken;
         if (fastMechanism == null || Strings.isNullOrEmpty(token)) {

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -14,6 +14,7 @@ import android.util.Pair;
 import android.util.SparseArray;
 
 import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
 
 import com.google.common.base.Strings;
 
@@ -704,7 +705,9 @@ public class XmppConnection implements Runnable {
         Log.d(
                 Config.LOGTAG,
                 account.getJid().asBareJid().toString() + ": logged in (using " + version + ")");
-        account.setPinnedMechanism(saslMechanism);
+        if (SaslMechanism.pin(this.saslMechanism)) {
+            account.setPinnedMechanism(this.saslMechanism);
+        }
         if (version == SaslMechanism.Version.SASL_2) {
             final Tag tag = tagReader.readTag();
             if (tag != null && tag.isStart("features", Namespace.STREAMS)) {
@@ -837,6 +840,7 @@ public class XmppConnection implements Runnable {
         }
         Log.d(Config.LOGTAG,failure.toString());
         Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": login failure " + version);
+        //TODO check if we are doing FAST; reset token
         if (failure.hasChild("temporary-auth-failure")) {
             throw new StateChangingException(Account.State.TEMPORARY_AUTH_FAILURE);
         } else if (failure.hasChild("account-disabled")) {
@@ -1242,6 +1246,7 @@ public class XmppConnection implements Runnable {
                         account.getJid().asBareJid()
                                 + ": quick start in progress. ignoring features: "
                                 + XmlHelper.printElementNames(this.streamFeatures));
+                //TODO check if 'fast' is available but we are doing something else
                 return;
             }
             Log.d(Config.LOGTAG,account.getJid().asBareJid()+": server lost support for SASL 2. quick start not possible");
@@ -1320,37 +1325,12 @@ public class XmppConnection implements Runnable {
         final Element cbElement =
                 this.streamFeatures.findChild("sasl-channel-binding", Namespace.CHANNEL_BINDING);
         final Collection<ChannelBinding> channelBindings = ChannelBinding.of(cbElement);
-        Log.d(Config.LOGTAG,"mechanisms: "+mechanisms);
-        Log.d(Config.LOGTAG, "channel bindings: " + channelBindings);
         final SaslMechanism.Factory factory = new SaslMechanism.Factory(account);
-        this.saslMechanism = factory.of(mechanisms, channelBindings, SSLSockets.version(this.socket));
-
-        //TODO externalize checks
-
-        if (saslMechanism == null) {
-            Log.d(
-                    Config.LOGTAG,
-                    account.getJid().asBareJid()
-                            + ": unable to find supported SASL mechanism in "
-                            + mechanisms);
-            throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
-        }
-        final int pinnedMechanism = account.getPinnedMechanismPriority();
-        if (pinnedMechanism > saslMechanism.getPriority()) {
-            Log.e(
-                    Config.LOGTAG,
-                    "Auth failed. Authentication mechanism "
-                            + saslMechanism.getMechanism()
-                            + " has lower priority ("
-                            + saslMechanism.getPriority()
-                            + ") than pinned priority ("
-                            + pinnedMechanism
-                            + "). Possible downgrade attack?");
-            throw new StateChangingException(Account.State.DOWNGRADE_ATTACK);
-        }
+        final SaslMechanism saslMechanism = factory.of(mechanisms, channelBindings, version, SSLSockets.version(this.socket));
+        this.saslMechanism = validate(saslMechanism, mechanisms);
         final boolean quickStartAvailable;
-        final String firstMessage = saslMechanism.getClientFirstMessage();
-        final boolean usingFast = saslMechanism instanceof HashedToken;
+        final String firstMessage = this.saslMechanism.getClientFirstMessage(sslSocketOrNull(this.socket));
+        final boolean usingFast = SaslMechanism.hashedToken(this.saslMechanism);
         final Element authenticate;
         if (version == SaslMechanism.Version.SASL) {
             authenticate = new Element("auth", Namespace.SASL);
@@ -1402,11 +1382,40 @@ public class XmppConnection implements Runnable {
                         + ": Authenticating with "
                         + version
                         + "/"
-                        + saslMechanism.getMechanism());
-        authenticate.setAttribute("mechanism", saslMechanism.getMechanism());
+                        + this.saslMechanism.getMechanism());
+        authenticate.setAttribute("mechanism", this.saslMechanism.getMechanism());
         tagWriter.writeElement(authenticate);
     }
 
+    @NonNull
+    private SaslMechanism validate(final @Nullable SaslMechanism saslMechanism, Collection<String> mechanisms) throws StateChangingException {
+        if (saslMechanism == null) {
+            Log.d(
+                    Config.LOGTAG,
+                    account.getJid().asBareJid()
+                            + ": unable to find supported SASL mechanism in "
+                            + mechanisms);
+            throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+        }
+        if (SaslMechanism.hashedToken(saslMechanism)) {
+            return saslMechanism;
+        }
+        final int pinnedMechanism = account.getPinnedMechanismPriority();
+        if (pinnedMechanism > saslMechanism.getPriority()) {
+            Log.e(
+                    Config.LOGTAG,
+                    "Auth failed. Authentication mechanism "
+                            + saslMechanism.getMechanism()
+                            + " has lower priority ("
+                            + saslMechanism.getPriority()
+                            + ") than pinned priority ("
+                            + pinnedMechanism
+                            + "). Possible downgrade attack?");
+            throw new StateChangingException(Account.State.DOWNGRADE_ATTACK);
+        }
+        return saslMechanism;
+    }
+
     private Element generateAuthenticationRequest(final String firstMessage, final boolean usingFast) {
         return generateAuthenticationRequest(firstMessage, usingFast, null, Bind2.QUICKSTART_FEATURES, true);
     }
@@ -2093,7 +2102,7 @@ public class XmppConnection implements Runnable {
             this.saslMechanism = quickStartMechanism;
             final boolean usingFast = quickStartMechanism instanceof HashedToken;
             final Element authenticate =
-                    generateAuthenticationRequest(quickStartMechanism.getClientFirstMessage(), usingFast);
+                    generateAuthenticationRequest(quickStartMechanism.getClientFirstMessage(sslSocketOrNull(this.socket)), usingFast);
             authenticate.setAttribute("mechanism", quickStartMechanism.getMechanism());
             sendStartStream(true, false);
             tagWriter.writeElement(authenticate);