add better state tracking to SASL

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/crypto/sasl/Anonymous.java      |  19 
src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java      | 241 
src/main/java/eu/siacs/conversations/crypto/sasl/External.java       |  17 
src/main/java/eu/siacs/conversations/crypto/sasl/Plain.java          |  32 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java  |  23 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java |   1 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java        |  17 
7 files changed, 237 insertions(+), 113 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/sasl/Anonymous.java 🔗

@@ -1,8 +1,9 @@
 package eu.siacs.conversations.crypto.sasl;
 
-import javax.net.ssl.SSLSocket;
-
+import com.google.common.base.Preconditions;
+import com.google.common.base.Strings;
 import eu.siacs.conversations.entities.Account;
+import javax.net.ssl.SSLSocket;
 
 public class Anonymous extends SaslMechanism {
 
@@ -24,6 +25,20 @@ public class Anonymous extends SaslMechanism {
 
     @Override
     public String getClientFirstMessage(final SSLSocket sslSocket) {
+        Preconditions.checkState(
+                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
+        this.state = State.AUTH_TEXT_SENT;
         return "";
     }
+
+    @Override
+    public String getResponse(final String challenge, final SSLSocket sslSocket)
+            throws AuthenticationException {
+        checkState(State.AUTH_TEXT_SENT);
+        if (Strings.isNullOrEmpty(challenge)) {
+            this.state = State.VALID_SERVER_RESPONSE;
+            return null;
+        }
+        throw new AuthenticationException("Unexpected server response");
+    }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java 🔗

@@ -1,20 +1,25 @@
 package eu.siacs.conversations.crypto.sasl;
 
-import android.util.Base64;
-
-import java.nio.charset.Charset;
-import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
-
-import javax.net.ssl.SSLSocket;
-
+import android.util.Log;
+import androidx.annotation.NonNull;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
+import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.hash.Hashing;
+import com.google.common.io.BaseEncoding;
+import eu.siacs.conversations.Config;
 import eu.siacs.conversations.entities.Account;
 import eu.siacs.conversations.utils.CryptoHelper;
+import java.nio.charset.Charset;
+import java.util.Map;
+import javax.net.ssl.SSLSocket;
 
 public class DigestMd5 extends SaslMechanism {
 
     public static final String MECHANISM = "DIGEST-MD5";
     private State state = State.INITIAL;
+    private String precalculatedRSPAuth;
 
     public DigestMd5(final Account account) {
         super(account);
@@ -31,84 +36,150 @@ public class DigestMd5 extends SaslMechanism {
     }
 
     @Override
-    public String getResponse(final String challenge, final SSLSocket sslSocket)
+    public String getClientFirstMessage(final SSLSocket sslSocket) {
+        Preconditions.checkState(
+                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
+        this.state = State.AUTH_TEXT_SENT;
+        return "";
+    }
+
+    @Override
+    public String getResponse(final String challenge, final SSLSocket socket)
+            throws AuthenticationException {
+        return switch (state) {
+            case AUTH_TEXT_SENT -> processChallenge(challenge, socket);
+            case RESPONSE_SENT -> validateServerResponse(challenge);
+            case VALID_SERVER_RESPONSE -> validateUnnecessarySuccessMessage(challenge);
+            default -> throw new InvalidStateException(state);
+        };
+    }
+
+    // ejabberd sends the RSPAuth response as a challenge and then an empty success
+    // technically this is allowed as per https://datatracker.ietf.org/doc/html/rfc2222#section-5.2
+    // although it says to do that only if the profile of the protocol does not allow data to be put
+    // into success. which xmpp does allow. obviously
+    private String validateUnnecessarySuccessMessage(final String challenge)
+            throws AuthenticationException {
+        if (Strings.isNullOrEmpty(challenge)) {
+            return "";
+        }
+        throw new AuthenticationException("Success message must be empty");
+    }
+
+    private String validateServerResponse(final String challenge) throws AuthenticationException {
+        Log.d(Config.LOGTAG, "DigestMd5.validateServerResponse(" + challenge + ")");
+        final var attributes = messageToAttributes(challenge);
+        Log.d(Config.LOGTAG, "attributes: " + attributes);
+        final var rspauth = attributes.get("rspauth");
+        if (Strings.isNullOrEmpty(rspauth)) {
+            throw new AuthenticationException("no rspauth in server finish message");
+        }
+        final var expected = this.precalculatedRSPAuth;
+        if (Strings.isNullOrEmpty(expected) || !this.precalculatedRSPAuth.equals(rspauth)) {
+            throw new AuthenticationException("RSPAuth mismatch");
+        }
+        this.state = State.VALID_SERVER_RESPONSE;
+        return "";
+    }
+
+    private String processChallenge(final String challenge, final SSLSocket socket)
             throws AuthenticationException {
-        switch (state) {
-            case INITIAL:
-                state = State.RESPONSE_SENT;
-                final String encodedResponse;
-                try {
-                    final Tokenizer tokenizer =
-                            new Tokenizer(Base64.decode(challenge, Base64.DEFAULT));
-                    String nonce = "";
-                    for (final String token : tokenizer) {
-                        final String[] parts = token.split("=", 2);
-                        if (parts[0].equals("nonce")) {
-                            nonce = parts[1].replace("\"", "");
-                        } else if (parts[0].equals("rspauth")) {
-                            return "";
-                        }
-                    }
-                    final String digestUri = "xmpp/" + account.getServer();
-                    final String nonceCount = "00000001";
-                    final String x =
-                            account.getUsername()
-                                    + ":"
-                                    + account.getServer()
-                                    + ":"
-                                    + account.getPassword();
-                    final MessageDigest md = MessageDigest.getInstance("MD5");
-                    final byte[] y = md.digest(x.getBytes(Charset.defaultCharset()));
-                    final String cNonce = CryptoHelper.random(100);
-                    final byte[] a1 =
-                            CryptoHelper.concatenateByteArrays(
-                                    y,
-                                    (":" + nonce + ":" + cNonce)
-                                            .getBytes(Charset.defaultCharset()));
-                    final String a2 = "AUTHENTICATE:" + digestUri;
-                    final String ha1 = CryptoHelper.bytesToHex(md.digest(a1));
-                    final String ha2 =
-                            CryptoHelper.bytesToHex(
-                                    md.digest(a2.getBytes(Charset.defaultCharset())));
-                    final String kd =
-                            ha1 + ":" + nonce + ":" + nonceCount + ":" + cNonce + ":auth:" + ha2;
-                    final String response =
-                            CryptoHelper.bytesToHex(
-                                    md.digest(kd.getBytes(Charset.defaultCharset())));
-                    final String saslString =
-                            "username=\""
-                                    + account.getUsername()
-                                    + "\",realm=\""
-                                    + account.getServer()
-                                    + "\",nonce=\""
-                                    + nonce
-                                    + "\",cnonce=\""
-                                    + cNonce
-                                    + "\",nc="
-                                    + nonceCount
-                                    + ",qop=auth,digest-uri=\""
-                                    + digestUri
-                                    + "\",response="
-                                    + response
-                                    + ",charset=utf-8";
-                    encodedResponse =
-                            Base64.encodeToString(
-                                    saslString.getBytes(Charset.defaultCharset()), Base64.NO_WRAP);
-                } catch (final NoSuchAlgorithmException e) {
-                    throw new AuthenticationException(e);
-                }
-
-                return encodedResponse;
-            case RESPONSE_SENT:
-                state = State.VALID_SERVER_RESPONSE;
-                break;
-            case VALID_SERVER_RESPONSE:
-                if (challenge == null) {
-                    return null; // everything is fine
-                }
-            default:
-                throw new InvalidStateException(state);
+        Log.d(Config.LOGTAG, "DigestMd5.processChallenge()");
+        this.state = State.RESPONSE_SENT;
+        final var attributes = messageToAttributes(challenge);
+
+        final var nonce = attributes.get("nonce");
+
+        if (Strings.isNullOrEmpty(nonce)) {
+            throw new AuthenticationException("Server nonce missing");
+        }
+        final String digestUri = "xmpp/" + account.getServer();
+        final String nonceCount = "00000001";
+        final String x =
+                account.getUsername() + ":" + account.getServer() + ":" + account.getPassword();
+        final byte[] y = Hashing.md5().hashBytes(x.getBytes(Charset.defaultCharset())).asBytes();
+        final String cNonce = CryptoHelper.random(100);
+        final byte[] a1 =
+                CryptoHelper.concatenateByteArrays(
+                        y, (":" + nonce + ":" + cNonce).getBytes(Charset.defaultCharset()));
+        final String a2 = "AUTHENTICATE:" + digestUri;
+        final String ha1 = CryptoHelper.bytesToHex(Hashing.md5().hashBytes(a1).asBytes());
+        final String ha2 =
+                CryptoHelper.bytesToHex(
+                        Hashing.md5().hashBytes(a2.getBytes(Charset.defaultCharset())).asBytes());
+        final String kd = ha1 + ":" + nonce + ":" + nonceCount + ":" + cNonce + ":auth:" + ha2;
+
+        final String a2ForResponse = ":" + digestUri;
+        final String ha2ForResponse =
+                CryptoHelper.bytesToHex(
+                        Hashing.md5()
+                                .hashBytes(a2ForResponse.getBytes(Charset.defaultCharset()))
+                                .asBytes());
+        final String kdForResponseInput =
+                ha1 + ":" + nonce + ":" + nonceCount + ":" + cNonce + ":auth:" + ha2ForResponse;
+
+        this.precalculatedRSPAuth =
+                CryptoHelper.bytesToHex(
+                        Hashing.md5()
+                                .hashBytes(kdForResponseInput.getBytes(Charset.defaultCharset()))
+                                .asBytes());
+
+        final String response =
+                CryptoHelper.bytesToHex(
+                        Hashing.md5().hashBytes(kd.getBytes(Charset.defaultCharset())).asBytes());
+
+        final String saslString =
+                "username=\""
+                        + account.getUsername()
+                        + "\",realm=\""
+                        + account.getServer()
+                        + "\",nonce=\""
+                        + nonce
+                        + "\",cnonce=\""
+                        + cNonce
+                        + "\",nc="
+                        + nonceCount
+                        + ",qop=auth,digest-uri=\""
+                        + digestUri
+                        + "\",response="
+                        + response
+                        + ",charset=utf-8";
+        return BaseEncoding.base64().encode(saslString.getBytes());
+    }
+
+    private static Map<String, String> messageToAttributes(final String message)
+            throws AuthenticationException {
+        byte[] asBytes;
+        try {
+            asBytes = BaseEncoding.base64().decode(message);
+        } catch (final IllegalArgumentException e) {
+            throw new AuthenticationException("Unable to decode server challenge", e);
+        }
+        try {
+            return splitToAttributes(new String(asBytes));
+        } catch (final IllegalArgumentException e) {
+            throw new AuthenticationException("Duplicate attributes");
+        }
+    }
+
+    private static Map<String, String> splitToAttributes(final String message) {
+        final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
+        for (final String token : Splitter.on(',').split(message)) {
+            final var tuple = Splitter.on('=').limit(2).splitToList(token);
+            if (tuple.size() == 2) {
+                final var value = tuple.get(1);
+                builder.put(tuple.get(0), trimQuotes(value));
+            }
+        }
+        return builder.buildOrThrow();
+    }
+
+    public static String trimQuotes(@NonNull final String input) {
+        if (input.length() >= 2
+                && input.charAt(0) == '"'
+                && input.charAt(input.length() - 1) == '"') {
+            return input.substring(1, input.length() - 1);
         }
-        return null;
+        return input;
     }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/External.java 🔗

@@ -1,6 +1,7 @@
 package eu.siacs.conversations.crypto.sasl;
 
-import android.util.Base64;
+import com.google.common.base.Preconditions;
+import com.google.common.io.BaseEncoding;
 import eu.siacs.conversations.entities.Account;
 import javax.net.ssl.SSLSocket;
 
@@ -24,7 +25,17 @@ public class External extends SaslMechanism {
 
     @Override
     public String getClientFirstMessage(final SSLSocket sslSocket) {
-        return Base64.encodeToString(
-                account.getJid().asBareJid().toString().getBytes(), Base64.NO_WRAP);
+        Preconditions.checkState(
+                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
+        this.state = State.AUTH_TEXT_SENT;
+        final String message = account.getJid().asBareJid().toString();
+        return BaseEncoding.base64().encode(message.getBytes());
+    }
+
+    @Override
+    public String getResponse(String challenge, SSLSocket sslSocket)
+            throws AuthenticationException {
+        // TODO check that state is in auth text sent and move to finished
+        return "";
     }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/Plain.java 🔗

@@ -1,12 +1,10 @@
 package eu.siacs.conversations.crypto.sasl;
 
-import android.util.Base64;
-
-import java.nio.charset.Charset;
-
-import javax.net.ssl.SSLSocket;
-
+import com.google.common.base.Preconditions;
+import com.google.common.base.Strings;
+import com.google.common.io.BaseEncoding;
 import eu.siacs.conversations.entities.Account;
+import javax.net.ssl.SSLSocket;
 
 public class Plain extends SaslMechanism {
 
@@ -16,11 +14,6 @@ public class Plain extends SaslMechanism {
         super(account);
     }
 
-    public static String getMessage(String username, String password) {
-        final String message = '\u0000' + username + '\u0000' + password;
-        return Base64.encodeToString(message.getBytes(Charset.defaultCharset()), Base64.NO_WRAP);
-    }
-
     @Override
     public int getPriority() {
         return 10;
@@ -33,6 +26,21 @@ public class Plain extends SaslMechanism {
 
     @Override
     public String getClientFirstMessage(final SSLSocket sslSocket) {
-        return getMessage(account.getUsername(), account.getPassword());
+        Preconditions.checkState(
+                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
+        this.state = State.AUTH_TEXT_SENT;
+        final String message = '\u0000' + account.getUsername() + '\u0000' + account.getPassword();
+        return BaseEncoding.base64().encode(message.getBytes());
+    }
+
+    @Override
+    public String getResponse(final String challenge, final SSLSocket sslSocket)
+            throws AuthenticationException {
+        checkState(State.AUTH_TEXT_SENT);
+        if (Strings.isNullOrEmpty(challenge)) {
+            this.state = State.VALID_SERVER_RESPONSE;
+            return null;
+        }
+        throw new AuthenticationException("Unexpected server response");
     }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -16,6 +16,8 @@ public abstract class SaslMechanism {
 
     protected final Account account;
 
+    protected State state = State.INITIAL;
+
     protected SaslMechanism(final Account account) {
         this.account = account;
     }
@@ -39,14 +41,10 @@ public abstract class SaslMechanism {
 
     public abstract String getMechanism();
 
-    public String getClientFirstMessage(final SSLSocket sslSocket) {
-        return "";
-    }
+    public abstract String getClientFirstMessage(final SSLSocket sslSocket);
 
-    public String getResponse(final String challenge, final SSLSocket sslSocket)
-            throws AuthenticationException {
-        return "";
-    }
+    public abstract String getResponse(final String challenge, final SSLSocket sslSocket)
+            throws AuthenticationException;
 
     public enum State {
         INITIAL,
@@ -55,6 +53,17 @@ public abstract class SaslMechanism {
         VALID_SERVER_RESPONSE,
     }
 
+    protected void checkState(final State expected) throws InvalidStateException {
+        final var current = this.state;
+        if (current == null) {
+            throw new InvalidStateException("Current state is null. Implementation problem");
+        }
+        if (current != expected) {
+            throw new InvalidStateException(
+                    String.format("State was %s. Expected %s", current, expected));
+        }
+    }
+
     public enum Version {
         SASL,
         SASL_2;

src/main/java/eu/siacs/conversations/crypto/sasl/ScramMechanism.java 🔗

@@ -48,7 +48,6 @@ public abstract class ScramMechanism extends SaslMechanism {
     protected final ChannelBinding channelBinding;
     private final String gs2Header;
     private final String clientNonce;
-    protected State state = State.INITIAL;
     private final String clientFirstMessageBare;
     private byte[] serverSignature = null;
     private DowngradeProtection downgradeProtection = null;

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -775,7 +775,9 @@ public class XmppConnection implements Runnable {
             throws IOException, XmlPullParserException {
         final LoginInfo currentLoginInfo = this.loginInfo;
         final SaslMechanism currentSaslMechanism = LoginInfo.mechanism(currentLoginInfo);
-        if (currentLoginInfo == null || currentSaslMechanism == null) {
+        if (currentLoginInfo == null
+                || LoginInfo.isSuccess(currentLoginInfo)
+                || currentSaslMechanism == null) {
             throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
         }
         final SaslMechanism.Version version;
@@ -987,9 +989,15 @@ public class XmppConnection implements Runnable {
         } catch (final IllegalArgumentException e) {
             throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
         }
+
+        final LoginInfo currentLoginInfo = this.loginInfo;
+        if (currentLoginInfo == null || LoginInfo.isSuccess(currentLoginInfo)) {
+            throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
+        }
+
         Log.d(Config.LOGTAG, failure.toString());
         Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": login failure " + version);
-        if (SaslMechanism.hashedToken(LoginInfo.mechanism(this.loginInfo))) {
+        if (SaslMechanism.hashedToken(LoginInfo.mechanism(currentLoginInfo))) {
             Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": resetting token");
             account.resetFastToken();
             mXmppConnectionService.databaseBackend.updateAccount(account);
@@ -1026,7 +1034,7 @@ public class XmppConnection implements Runnable {
                 }
             }
         }
-        if (SaslMechanism.hashedToken(LoginInfo.mechanism(this.loginInfo))) {
+        if (SaslMechanism.hashedToken(LoginInfo.mechanism(currentLoginInfo))) {
             Log.d(
                     Config.LOGTAG,
                     account.getJid().asBareJid()
@@ -2913,6 +2921,9 @@ public class XmppConnection implements Runnable {
 
         public void success(final String challenge, final SSLSocket sslSocket)
                 throws SaslMechanism.AuthenticationException {
+            if (Thread.currentThread().isInterrupted()) {
+                throw new SaslMechanism.AuthenticationException("Race condition during auth");
+            }
             final var response = this.saslMechanism.getResponse(challenge, sslSocket);
             if (!Strings.isNullOrEmpty(response)) {
                 throw new SaslMechanism.AuthenticationException(