ensure we only select channel binding methods available for tls version

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/Config.java                     |  2 
src/main/java/eu/siacs/conversations/crypto/sasl/ChannelBinding.java | 48 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java  | 58 
src/main/java/eu/siacs/conversations/utils/SSLSockets.java           | 46 
src/main/java/eu/siacs/conversations/utils/TLSSocketFactory.java     |  4 
src/main/java/eu/siacs/conversations/xml/Namespace.java              |  1 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java        | 88 
7 files changed, 181 insertions(+), 66 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/Config.java 🔗

@@ -60,7 +60,7 @@ public final class Config {
     public static final long CONTACT_SYNC_RETRY_INTERVAL = 1000L * 60 * 5;
 
 
-    public static final boolean SASL_2_ENABLED = true;
+    public static final boolean QUICKSTART_ENABLED = true;
 
     //Notification settings
     public static final boolean HIDE_MESSAGE_TEXT_IN_NOTIFICATION = false;

src/main/java/eu/siacs/conversations/crypto/sasl/ChannelBinding.java 🔗

@@ -3,11 +3,19 @@ package eu.siacs.conversations.crypto.sasl;
 import android.util.Log;
 
 import com.google.common.base.CaseFormat;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Predicates;
 import com.google.common.base.Strings;
+import com.google.common.collect.Collections2;
 
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 
 import eu.siacs.conversations.Config;
+import eu.siacs.conversations.utils.SSLSockets;
+import eu.siacs.conversations.xml.Element;
+import eu.siacs.conversations.xml.Namespace;
 
 public enum ChannelBinding {
     NONE,
@@ -15,7 +23,24 @@ public enum ChannelBinding {
     TLS_SERVER_END_POINT,
     TLS_UNIQUE;
 
-    public static ChannelBinding of(final String type) {
+    public static Collection<ChannelBinding> of(final Element channelBinding) {
+        Preconditions.checkArgument(
+                channelBinding == null
+                        || ("sasl-channel-binding".equals(channelBinding.getName())
+                                && Namespace.CHANNEL_BINDING.equals(channelBinding.getNamespace())),
+                "pass null or a valid channel binding stream feature");
+        return Collections2.filter(
+                Collections2.transform(
+                        Collections2.filter(
+                                channelBinding == null
+                                        ? Collections.emptyList()
+                                        : channelBinding.getChildren(),
+                                c -> c != null && "channel-binding".equals(c.getName())),
+                        c -> c == null ? null : ChannelBinding.of(c.getAttribute("type"))),
+                Predicates.notNull());
+    }
+
+    private static ChannelBinding of(final String type) {
         if (type == null) {
             return null;
         }
@@ -39,15 +64,28 @@ public enum ChannelBinding {
         }
     }
 
-    public static ChannelBinding best(final Collection<ChannelBinding> bindings) {
-        if (bindings.contains(TLS_EXPORTER)) {
+    public static ChannelBinding best(
+            final Collection<ChannelBinding> bindings, final SSLSockets.Version sslVersion) {
+        if (sslVersion == SSLSockets.Version.NONE) {
+            return NONE;
+        }
+        if (bindings.contains(TLS_EXPORTER) && sslVersion == SSLSockets.Version.TLS_1_3) {
             return TLS_EXPORTER;
-        } else if (bindings.contains(TLS_UNIQUE)) {
+        } else if (bindings.contains(TLS_UNIQUE)
+                && Arrays.asList(
+                                SSLSockets.Version.TLS_1_0,
+                                SSLSockets.Version.TLS_1_1,
+                                SSLSockets.Version.TLS_1_2)
+                        .contains(sslVersion)) {
             return TLS_UNIQUE;
         } else if (bindings.contains(TLS_SERVER_END_POINT)) {
             return TLS_SERVER_END_POINT;
         } else {
-            return null;
+            return NONE;
         }
     }
+
+    public static boolean ensureBest(final ChannelBinding channelBinding, final SSLSockets.Version sslVersion) {
+        return ChannelBinding.best(Collections.singleton(channelBinding), sslVersion) == channelBinding;
+    }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -1,13 +1,19 @@
 package eu.siacs.conversations.crypto.sasl;
 
+import android.util.Log;
+
+import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
+import com.google.common.collect.Collections2;
 
 import java.util.Collection;
 import java.util.Collections;
 
 import javax.net.ssl.SSLSocket;
 
+import eu.siacs.conversations.Config;
 import eu.siacs.conversations.entities.Account;
+import eu.siacs.conversations.utils.SSLSockets;
 import eu.siacs.conversations.xml.Element;
 import eu.siacs.conversations.xml.Namespace;
 
@@ -47,6 +53,17 @@ public abstract class SaslMechanism {
         return "";
     }
 
+    public static Collection<String> mechanisms(final Element authElement) {
+        if (authElement == null) {
+            return Collections.emptyList();
+        }
+        return Collections2.transform(
+                Collections2.filter(
+                        authElement.getChildren(),
+                        c -> c != null && "mechanism".equals(c.getName())),
+                c -> c == null ? null : c.getContent());
+    }
+
     protected enum State {
         INITIAL,
         AUTH_TEXT_SENT,
@@ -102,16 +119,19 @@ public abstract class SaslMechanism {
             this.account = account;
         }
 
-        public SaslMechanism of(
-                final Collection<String> mechanisms, final Collection<ChannelBinding> bindings) {
-            final ChannelBinding channelBinding = ChannelBinding.best(bindings);
+        private SaslMechanism of(
+                final Collection<String> mechanisms, final ChannelBinding channelBinding) {
+            Preconditions.checkNotNull(channelBinding, "Use ChannelBinding.NONE instead of null");
             if (mechanisms.contains(External.MECHANISM) && account.getPrivateKeyAlias() != null) {
                 return new External(account);
-            } else if (mechanisms.contains(ScramSha512Plus.MECHANISM) && channelBinding != null) {
+            } else if (mechanisms.contains(ScramSha512Plus.MECHANISM)
+                    && channelBinding != ChannelBinding.NONE) {
                 return new ScramSha512Plus(account, channelBinding);
-            } else if (mechanisms.contains(ScramSha256Plus.MECHANISM) && channelBinding != null) {
+            } else if (mechanisms.contains(ScramSha256Plus.MECHANISM)
+                    && channelBinding != ChannelBinding.NONE) {
                 return new ScramSha256Plus(account, channelBinding);
-            } else if (mechanisms.contains(ScramSha1Plus.MECHANISM) && channelBinding != null) {
+            } else if (mechanisms.contains(ScramSha1Plus.MECHANISM)
+                    && channelBinding != ChannelBinding.NONE) {
                 return new ScramSha1Plus(account, channelBinding);
             } else if (mechanisms.contains(ScramSha512.MECHANISM)) {
                 return new ScramSha512(account);
@@ -131,9 +151,33 @@ public abstract class SaslMechanism {
             }
         }
 
+        public SaslMechanism of(
+                final Collection<String> mechanisms,
+                final Collection<ChannelBinding> bindings,
+                final SSLSockets.Version sslVersion) {
+            final ChannelBinding channelBinding = ChannelBinding.best(bindings, sslVersion);
+            return of(mechanisms, channelBinding);
+        }
+
         public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
-            return of(Collections.singleton(mechanism), Collections.singleton(channelBinding));
+            return of(Collections.singleton(mechanism), channelBinding);
         }
+    }
 
+    public static SaslMechanism ensureAvailable(
+            final SaslMechanism mechanism, final SSLSockets.Version sslVersion) {
+        if (mechanism instanceof ScramPlusMechanism) {
+            final ChannelBinding cb = ((ScramPlusMechanism) mechanism).getChannelBinding();
+            if (ChannelBinding.ensureBest(cb, sslVersion)) {
+                return mechanism;
+            } else {
+                Log.d(
+                        Config.LOGTAG,
+                        "pinned channel binding method " + cb + " no longer available");
+                return null;
+            }
+        } else {
+            return mechanism;
+        }
     }
 }

src/main/java/eu/siacs/conversations/utils/SSLSocketHelper.java → src/main/java/eu/siacs/conversations/utils/SSLSockets.java 🔗

@@ -5,9 +5,12 @@ import android.util.Log;
 
 import androidx.annotation.RequiresApi;
 
+import com.google.common.base.Strings;
+
 import org.conscrypt.Conscrypt;
 
 import java.lang.reflect.Method;
+import java.net.Socket;
 import java.nio.charset.StandardCharsets;
 import java.security.NoSuchAlgorithmException;
 import java.util.Arrays;
@@ -24,7 +27,7 @@ import javax.net.ssl.SSLSocket;
 import eu.siacs.conversations.Config;
 import eu.siacs.conversations.entities.Account;
 
-public class SSLSocketHelper {
+public class SSLSockets {
 
     public static void setSecurity(final SSLSocket sslSocket) {
         final String[] supportProtocols;
@@ -100,6 +103,45 @@ public class SSLSocketHelper {
 
     public static void log(Account account, SSLSocket socket) {
         SSLSession session = socket.getSession();
-        Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": protocol=" + session.getProtocol() + " cipher=" + session.getCipherSuite());
+        Log.d(
+                Config.LOGTAG,
+                account.getJid().asBareJid()
+                        + ": protocol="
+                        + session.getProtocol()
+                        + " cipher="
+                        + session.getCipherSuite());
+    }
+
+    public static Version version(final Socket socket) {
+        if (socket instanceof SSLSocket) {
+            final SSLSocket sslSocket = (SSLSocket) socket;
+            return Version.of(sslSocket.getSession().getProtocol());
+        } else {
+            return Version.NONE;
+        }
+    }
+
+    public enum Version {
+        TLS_1_0,
+        TLS_1_1,
+        TLS_1_2,
+        TLS_1_3,
+        UNKNOWN,
+        NONE;
+
+        private static Version of(final String protocol) {
+            switch (Strings.nullToEmpty(protocol)) {
+                case "TLSv1":
+                    return TLS_1_0;
+                case "TLSv1.1":
+                    return TLS_1_1;
+                case "TLSv1.2":
+                    return TLS_1_2;
+                case "TLSv1.3":
+                    return TLS_1_3;
+                default:
+                    return UNKNOWN;
+            }
+        }
     }
 }

src/main/java/eu/siacs/conversations/utils/TLSSocketFactory.java 🔗

@@ -17,7 +17,7 @@ public class TLSSocketFactory extends SSLSocketFactory {
     private final SSLSocketFactory internalSSLSocketFactory;
 
     public TLSSocketFactory(X509TrustManager[] trustManager, SecureRandom random) throws KeyManagementException, NoSuchAlgorithmException {
-        SSLContext context = SSLSocketHelper.getSSLContext();
+        SSLContext context = SSLSockets.getSSLContext();
         context.init(null, trustManager, random);
         this.internalSSLSocketFactory = context.getSocketFactory();
     }
@@ -59,7 +59,7 @@ public class TLSSocketFactory extends SSLSocketFactory {
 
     private static Socket enableTLSOnSocket(Socket socket) {
         if(socket instanceof SSLSocket) {
-            SSLSocketHelper.setSecurity((SSLSocket) socket);
+            SSLSockets.setSecurity((SSLSocket) socket);
         }
         return socket;
     }

src/main/java/eu/siacs/conversations/xml/Namespace.java 🔗

@@ -19,6 +19,7 @@ public final class Namespace {
     public static final String SASL = "urn:ietf:params:xml:ns:xmpp-sasl";
     public static final String SASL_2 = "urn:xmpp:sasl:2";
     public static final String CHANNEL_BINDING = "urn:xmpp:sasl-cb:0";
+    public static final String FAST = "urn:xmpp:fast:0";
     public static final String TLS = "urn:ietf:params:xml:ns:xmpp-tls";
     public static final String PUBSUB = "http://jabber.org/protocol/pubsub";
     public static final String PUBSUB_PUBLISH_OPTIONS = PUBSUB + "#publish-options";

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -15,9 +15,7 @@ import android.util.SparseArray;
 
 import androidx.annotation.NonNull;
 
-import com.google.common.base.Predicates;
 import com.google.common.base.Strings;
-import com.google.common.collect.Collections2;
 
 import org.xmlpull.v1.XmlPullParserException;
 
@@ -80,7 +78,7 @@ import eu.siacs.conversations.utils.CryptoHelper;
 import eu.siacs.conversations.utils.Patterns;
 import eu.siacs.conversations.utils.PhoneHelper;
 import eu.siacs.conversations.utils.Resolver;
-import eu.siacs.conversations.utils.SSLSocketHelper;
+import eu.siacs.conversations.utils.SSLSockets;
 import eu.siacs.conversations.utils.SocksSocketFactory;
 import eu.siacs.conversations.utils.XmlHelper;
 import eu.siacs.conversations.xml.Element;
@@ -494,10 +492,11 @@ public class XmppConnection implements Runnable {
         tagWriter.beginDocument();
         final boolean quickStart;
         if (socket instanceof SSLSocket) {
-            SSLSocketHelper.log(account, (SSLSocket) socket);
-            quickStart = establishStream(true);
+            final SSLSocket sslSocket = (SSLSocket) socket;
+            SSLSockets.log(account, sslSocket);
+            quickStart = establishStream(SSLSockets.version(sslSocket));
         } else {
-            quickStart = establishStream(false);
+            quickStart = establishStream(SSLSockets.Version.NONE);
         }
         final Tag tag = tagReader.readTag();
         if (Thread.currentThread().isInterrupted()) {
@@ -512,7 +511,7 @@ public class XmppConnection implements Runnable {
 
     private SSLSocketFactory getSSLSocketFactory()
             throws NoSuchAlgorithmException, KeyManagementException {
-        final SSLContext sc = SSLSocketHelper.getSSLContext();
+        final SSLContext sc = SSLSockets.getSSLContext();
         final MemorizingTrustManager trustManager =
                 this.mXmppConnectionService.getMemorizingTrustManager();
         final KeyManager[] keyManager;
@@ -720,7 +719,6 @@ public class XmppConnection implements Runnable {
                                 + ": server did not send stream features after SASL2 success");
                 throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
             }
-            Log.d(Config.LOGTAG, "success: " + success);
             final String authorizationIdentifier =
                     success.findChildContent("authorization-identifier");
             final Jid authorizationJid;
@@ -785,7 +783,7 @@ public class XmppConnection implements Runnable {
                     processEnabled(streamManagementEnabled);
                     waitForDisco = true;
                 } else {
-                    //if we didn’t enable stream managment in bind do it now
+                    //if we did not enable stream management in bind do it now
                     waitForDisco = enableStreamManagement();
                 }
                 if (carbonsEnabled != null) {
@@ -800,7 +798,7 @@ public class XmppConnection implements Runnable {
         this.quickStartInProgress = false;
         if (version == SaslMechanism.Version.SASL) {
             tagReader.reset();
-            sendStartStream(true);
+            sendStartStream(false, true);
             final Tag tag = tagReader.readTag();
             if (tag != null && tag.isStart("stream", Namespace.STREAMS)) {
                 processStream();
@@ -1163,7 +1161,7 @@ public class XmppConnection implements Runnable {
         Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": TLS connection established");
         final boolean quickStart;
         try {
-            quickStart = establishStream(true);
+            quickStart = establishStream(SSLSockets.version(sslSocket));
         } catch (final InterruptedException e) {
             return;
         }
@@ -1173,7 +1171,7 @@ public class XmppConnection implements Runnable {
         features.encryptionEnabled = true;
         final Tag tag = tagReader.readTag();
         if (tag != null && tag.isStart("stream", Namespace.STREAMS)) {
-            SSLSocketHelper.log(account, sslSocket);
+            SSLSockets.log(account, sslSocket);
             processStream();
         } else {
             throw new StateChangingException(Account.State.STREAM_OPENING_ERROR);
@@ -1193,9 +1191,9 @@ public class XmppConnection implements Runnable {
                 (SSLSocket)
                         sslSocketFactory.createSocket(
                                 socket, address.getHostAddress(), socket.getPort(), true);
-        SSLSocketHelper.setSecurity(sslSocket);
-        SSLSocketHelper.setHostname(sslSocket, IDN.toASCII(account.getServer()));
-        SSLSocketHelper.setApplicationProtocol(sslSocket, "xmpp-client");
+        SSLSockets.setSecurity(sslSocket);
+        SSLSockets.setHostname(sslSocket, IDN.toASCII(account.getServer()));
+        SSLSockets.setApplicationProtocol(sslSocket, "xmpp-client");
         final XmppDomainVerifier xmppDomainVerifier = new XmppDomainVerifier();
         try {
             if (!xmppDomainVerifier.verify(
@@ -1251,8 +1249,7 @@ public class XmppConnection implements Runnable {
         } else if (!this.streamFeatures.hasChild("register", Namespace.REGISTER_STREAM_FEATURE)
                 && account.isOptionSet(Account.OPTION_REGISTER)) {
             throw new StateChangingException(Account.State.REGISTRATION_NOT_SUPPORTED);
-        } else if (Config.SASL_2_ENABLED
-                && this.streamFeatures.hasChild("authentication", Namespace.SASL_2)
+        } else if (this.streamFeatures.hasChild("authentication", Namespace.SASL_2)
                 && shouldAuthenticate
                 && isSecure) {
             authenticate(SaslMechanism.Version.SASL_2);
@@ -1301,29 +1298,14 @@ public class XmppConnection implements Runnable {
         } else {
             authElement = this.streamFeatures.findChild("authentication", Namespace.SASL_2);
         }
-        //TODO externalize
-        final Collection<String> mechanisms =
-                Collections2.transform(
-                        Collections2.filter(
-                                authElement.getChildren(),
-                                c -> c != null && "mechanism".equals(c.getName())),
-                        c -> c == null ? null : c.getContent());
+        final Collection<String> mechanisms = SaslMechanism.mechanisms(authElement);
         final Element cbElement =
                 this.streamFeatures.findChild("sasl-channel-binding", Namespace.CHANNEL_BINDING);
-        final Collection<ChannelBinding> channelBindings =
-                Collections2.filter(
-                        Collections2.transform(
-                                Collections2.filter(
-                                        cbElement == null
-                                                ? Collections.emptyList()
-                                                : cbElement.getChildren(),
-                                        c -> c != null && "channel-binding".equals(c.getName())),
-                                c -> c == null ? null : ChannelBinding.of(c.getAttribute("type"))),
-                        Predicates.notNull());
+        final Collection<ChannelBinding> channelBindings = ChannelBinding.of(cbElement);
         Log.d(Config.LOGTAG,"mechanisms: "+mechanisms);
         Log.d(Config.LOGTAG, "channel bindings: " + channelBindings);
         final SaslMechanism.Factory factory = new SaslMechanism.Factory(account);
-        this.saslMechanism = factory.of(mechanisms, channelBindings);
+        this.saslMechanism = factory.of(mechanisms, channelBindings, SSLSockets.version(this.socket));
 
         //TODO externalize checks
 
@@ -1360,6 +1342,9 @@ public class XmppConnection implements Runnable {
         } else if (version == SaslMechanism.Version.SASL_2) {
             final Element inline = authElement.findChild("inline", Namespace.SASL_2);
             final boolean sm = inline != null && inline.hasChild("sm", "urn:xmpp:sm:3");
+            final Element fast = inline == null ? null : inline.findChild("fast", Namespace.FAST);
+            final Collection<String> fastMechanisms = SaslMechanism.mechanisms(fast);
+            Log.d(Config.LOGTAG,"fast mechanisms: "+fastMechanisms);
             final Collection<String> bindFeatures = Bind2.features(inline);
             quickStartAvailable =
                     sm
@@ -1434,12 +1419,11 @@ public class XmppConnection implements Runnable {
         Log.d(Config.LOGTAG, "inline bind features: " + bindFeatures);
         final Element bind = new Element("bind", Namespace.BIND2);
         bind.addChild("tag").setContent(mXmppConnectionService.getString(R.string.app_name));
-        final Element features = bind.addChild("features");
         if (bindFeatures.contains(Namespace.CARBONS)) {
-            features.addChild("enable", Namespace.CARBONS);
+            bind.addChild("enable", Namespace.CARBONS);
         }
         if (bindFeatures.contains(Namespace.STREAM_MANAGEMENT)) {
-            features.addChild(new EnablePacket());
+            bind.addChild(new EnablePacket());
         }
         return bind;
     }
@@ -2060,34 +2044,40 @@ public class XmppConnection implements Runnable {
         }
     }
 
-    private boolean establishStream(final boolean secureConnection) throws IOException, InterruptedException {
-        final SaslMechanism saslMechanism = account.getPinnedMechanism();
+    private boolean establishStream(final SSLSockets.Version sslVersion)
+            throws IOException, InterruptedException {
+        final SaslMechanism pinnedMechanism =
+                SaslMechanism.ensureAvailable(account.getPinnedMechanism(), sslVersion);
+        final boolean secureConnection = sslVersion != SSLSockets.Version.NONE;
         if (secureConnection
-                && Config.SASL_2_ENABLED
-                && saslMechanism != null
+                && Config.QUICKSTART_ENABLED
+                && pinnedMechanism != null
                 && account.isOptionSet(Account.OPTION_QUICKSTART_AVAILABLE)) {
             mXmppConnectionService.restoredFromDatabaseLatch.await();
-            this.saslMechanism = saslMechanism;
+            this.saslMechanism = pinnedMechanism;
             final Element authenticate =
-                    generateAuthenticationRequest(saslMechanism.getClientFirstMessage());
-            authenticate.setAttribute("mechanism", saslMechanism.getMechanism());
-            sendStartStream(false);
+                    generateAuthenticationRequest(pinnedMechanism.getClientFirstMessage());
+            authenticate.setAttribute("mechanism", pinnedMechanism.getMechanism());
+            sendStartStream(true, false);
             tagWriter.writeElement(authenticate);
             Log.d(
                     Config.LOGTAG,
                     account.getJid().toString()
                             + ": quick start with "
-                            + saslMechanism.getMechanism());
+                            + pinnedMechanism.getMechanism());
             return true;
         } else {
-            sendStartStream(true);
+            sendStartStream(secureConnection, true);
             return false;
         }
     }
 
-    private void sendStartStream(final boolean flush) throws IOException {
+    private void sendStartStream(final boolean from, final boolean flush) throws IOException {
         final Tag stream = Tag.start("stream:stream");
         stream.setAttribute("to", account.getServer());
+        if (from) {
+            stream.setAttribute("from", account.getJid().asBareJid().toEscapedString());
+        }
         stream.setAttribute("version", "1.0");
         stream.setAttribute("xml:lang", LocalizedContent.STREAM_LANGUAGE);
         stream.setAttribute("xmlns", "jabber:client");