implement XEP-0474: SASL SCRAM Downgrade Protection

Daniel Gultsch created

Change summary

conversations.doap                                                           |  7 
src/main/java/eu/siacs/conversations/crypto/sasl/DowngradeProtection.java    | 98 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java         | 28 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java                | 12 
src/main/java/im/conversations/android/xmpp/model/cb/SaslChannelBinding.java |  8 
5 files changed, 149 insertions(+), 4 deletions(-)

Detailed changes

conversations.doap 🔗

@@ -490,6 +490,13 @@
             <xmpp:version>0.1.0</xmpp:version>
         </xmpp:SupportedXep>
     </implements>
+    <implements>
+        <xmpp:SupportedXep>
+            <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0474.html"/>
+            <xmpp:status>complete</xmpp:status>
+            <xmpp:version>0.3.1</xmpp:version>
+        </xmpp:SupportedXep>
+    </implements>
     <implements>
         <xmpp:SupportedXep>
             <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0484.html"/>

src/main/java/eu/siacs/conversations/crypto/sasl/DowngradeProtection.java 🔗

@@ -0,0 +1,98 @@
+package eu.siacs.conversations.crypto.sasl;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Joiner;
+import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Ordering;
+import java.util.Collection;
+
+public class DowngradeProtection {
+
+    private static final char SEPARATOR = ',';
+    private static final char SEPARATOR_MECHANISM_AND_BINDING = '|';
+
+    public final ImmutableList<String> mechanisms;
+    public final ImmutableList<String> channelBindings;
+
+    public DowngradeProtection(
+            final Collection<String> mechanisms, final Collection<String> channelBindings) {
+        this.mechanisms = Ordering.natural().immutableSortedCopy(mechanisms);
+        this.channelBindings = Ordering.natural().immutableSortedCopy(channelBindings);
+    }
+
+    public DowngradeProtection(final Collection<String> mechanisms) {
+        this.mechanisms = Ordering.natural().immutableSortedCopy(mechanisms);
+        this.channelBindings = null;
+    }
+
+    public String asDString() {
+        ensureSaslMechanismFormat(this.mechanisms);
+        ensureNoSeparators(this.mechanisms);
+        if (this.channelBindings != null) {
+            ensureNoSeparators(this.channelBindings);
+            ensureBindingFormat(this.channelBindings);
+            final var builder = new StringBuilder();
+            Joiner.on(SEPARATOR).appendTo(builder, mechanisms);
+            builder.append(SEPARATOR_MECHANISM_AND_BINDING);
+            Joiner.on(SEPARATOR).appendTo(builder, channelBindings);
+            return builder.toString();
+        } else {
+            return Joiner.on(SEPARATOR).join(mechanisms);
+        }
+    }
+
+    private static void ensureNoSeparators(final Iterable<String> list) {
+        for (final String item : list) {
+            if (item.indexOf(SEPARATOR) >= 0
+                    || item.indexOf(SEPARATOR_MECHANISM_AND_BINDING) >= 0) {
+                throw new SecurityException("illegal chars found in list");
+            }
+        }
+    }
+
+    private static void ensureSaslMechanismFormat(final Iterable<String> names) {
+        for (final String name : names) {
+            ensureSaslMechanismFormat(name);
+        }
+    }
+
+    private static void ensureSaslMechanismFormat(final String name) {
+        if (Strings.isNullOrEmpty(name)) {
+            throw new SecurityException("Empty sasl mechanism names are not permitted");
+        }
+        // https://www.rfc-editor.org/rfc/rfc4422.html#section-3.1
+        if (name.length() <= 20
+                && CharMatcher.inRange('A', 'Z')
+                        .or(CharMatcher.inRange('0', '9'))
+                        .or(CharMatcher.is('-'))
+                        .or(CharMatcher.is('_'))
+                        .matchesAllOf(name)
+                && !Character.isDigit(name.charAt(0))) {
+            return;
+        }
+        throw new SecurityException("Encountered illegal sasl name");
+    }
+
+    private static void ensureBindingFormat(final Iterable<String> names) {
+        for (final String name : names) {
+            ensureBindingFormat(name);
+        }
+    }
+
+    private static void ensureBindingFormat(final String name) {
+        if (Strings.isNullOrEmpty(name)) {
+            throw new SecurityException("Empty binding names are not permitted");
+        }
+        // https://www.rfc-editor.org/rfc/rfc5056.html#section-7d
+        if (CharMatcher.inRange('A', 'Z')
+                .or(CharMatcher.inRange('a', 'z'))
+                .or(CharMatcher.inRange('0', '9'))
+                .or(CharMatcher.is('.'))
+                .or(CharMatcher.is('-'))
+                .matchesAllOf(name)) {
+            return;
+        }
+        throw new SecurityException("Encountered illegal binding name");
+    }
+}

src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java 🔗

@@ -3,6 +3,7 @@ package eu.siacs.conversations.crypto.sasl;
 import com.google.common.base.CaseFormat;
 import com.google.common.base.Joiner;
 import com.google.common.base.Objects;
+import com.google.common.base.Preconditions;
 import com.google.common.base.Splitter;
 import com.google.common.base.Strings;
 import com.google.common.cache.Cache;
@@ -20,7 +21,7 @@ import java.util.concurrent.ExecutionException;
 import javax.crypto.SecretKey;
 import javax.net.ssl.SSLSocket;
 
-abstract class ScramMechanism extends SaslMechanism {
+public abstract class ScramMechanism extends SaslMechanism {
 
     public static final SecretKey EMPTY_KEY =
             new SecretKey() {
@@ -50,6 +51,7 @@ abstract class ScramMechanism extends SaslMechanism {
     protected State state = State.INITIAL;
     private final String clientFirstMessageBare;
     private byte[] serverSignature = null;
+    private DowngradeProtection downgradeProtection = null;
 
     ScramMechanism(final Account account, final ChannelBinding channelBinding) {
         super(account);
@@ -76,6 +78,12 @@ abstract class ScramMechanism extends SaslMechanism {
                         this.clientNonce);
     }
 
+    public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
+        Preconditions.checkState(
+                this.state == State.INITIAL, "setting downgrade protection in invalid state");
+        this.downgradeProtection = downgradeProtection;
+    }
+
     protected abstract HashFunction getHMac(final byte[] key);
 
     protected abstract HashFunction getDigest();
@@ -128,9 +136,8 @@ abstract class ScramMechanism extends SaslMechanism {
 
     @Override
     public String getClientFirstMessage(final SSLSocket sslSocket) {
-        if (this.state != State.INITIAL) {
-            throw new IllegalArgumentException("Calling getClientFirstMessage from invalid state");
-        }
+        Preconditions.checkState(
+                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
         this.state = State.AUTH_TEXT_SENT;
         final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
         return BaseEncoding.base64().encode(message);
@@ -198,6 +205,19 @@ abstract class ScramMechanism extends SaslMechanism {
             throw new AuthenticationException("Invalid salt in server first message");
         }
 
+        if (d != null && this.downgradeProtection != null) {
+            final String asSeenInFeatures;
+            try {
+                asSeenInFeatures = downgradeProtection.asDString();
+            } catch (final SecurityException e) {
+                throw new AuthenticationException(e);
+            }
+            final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
+            if (!hashed.equals(d)) {
+                throw new AuthenticationException("Mismatch in SSDP");
+            }
+        }
+
         final byte[] channelBindingData = getChannelBindingData(socket);
 
         final int gs2Len = this.gs2Header.getBytes().length;

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -29,8 +29,10 @@ import eu.siacs.conversations.crypto.XmppDomainVerifier;
 import eu.siacs.conversations.crypto.axolotl.AxolotlService;
 import eu.siacs.conversations.crypto.sasl.ChannelBinding;
 import eu.siacs.conversations.crypto.sasl.ChannelBindingMechanism;
+import eu.siacs.conversations.crypto.sasl.DowngradeProtection;
 import eu.siacs.conversations.crypto.sasl.HashedToken;
 import eu.siacs.conversations.crypto.sasl.SaslMechanism;
+import eu.siacs.conversations.crypto.sasl.ScramMechanism;
 import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.entities.Message;
 import eu.siacs.conversations.entities.ServiceDiscoveryResult;
@@ -1558,6 +1560,16 @@ public class XmppConnection implements Runnable {
         final SaslMechanism saslMechanism =
                 factory.of(mechanisms, channelBindings, version, SSLSockets.version(this.socket));
         this.validate(saslMechanism, mechanisms);
+        final DowngradeProtection downgradeProtection;
+        if (cbExtension != null) {
+            downgradeProtection =
+                    new DowngradeProtection(mechanisms, cbExtension.getChannelBindingTypes());
+        } else {
+            downgradeProtection = new DowngradeProtection(mechanisms);
+        }
+        if (saslMechanism instanceof ScramMechanism scramMechanism) {
+            scramMechanism.setDowngradeProtection(downgradeProtection);
+        }
         final boolean quickStartAvailable;
         final String firstMessage =
                 saslMechanism.getClientFirstMessage(sslSocketOrNull(this.socket));

src/main/java/im/conversations/android/xmpp/model/cb/SaslChannelBinding.java 🔗

@@ -1,5 +1,7 @@
 package im.conversations.android.xmpp.model.cb;
 
+import com.google.common.base.Predicates;
+import com.google.common.collect.Collections2;
 import im.conversations.android.annotation.XmlElement;
 import im.conversations.android.xmpp.model.StreamFeature;
 import java.util.Collection;
@@ -14,4 +16,10 @@ public class SaslChannelBinding extends StreamFeature {
     public Collection<ChannelBinding> getChannelBindings() {
         return this.getExtensions(ChannelBinding.class);
     }
+
+    public Collection<String> getChannelBindingTypes() {
+        return Collections2.filter(
+                Collections2.transform(getChannelBindings(), ChannelBinding::getType),
+                Predicates.notNull());
+    }
 }