store full sasl mechanism (not just priority)

Daniel Gultsch created

Change summary

src/conversations/java/eu/siacs/conversations/ui/MagicCreateActivity.java |  2 
src/main/java/eu/siacs/conversations/crypto/sasl/ChannelBinding.java      | 12 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java       |  6 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramPlusMechanism.java  |  9 
src/main/java/eu/siacs/conversations/entities/Account.java                | 72 
src/main/java/eu/siacs/conversations/persistance/DatabaseBackend.java     | 46 
src/main/java/eu/siacs/conversations/ui/EditAccountActivity.java          |  6 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java             |  7 
8 files changed, 117 insertions(+), 43 deletions(-)

Detailed changes

src/conversations/java/eu/siacs/conversations/ui/MagicCreateActivity.java 🔗

@@ -100,7 +100,7 @@ public class MagicCreateActivity extends XmppActivity implements TextWatcher {
                         account.setOption(Account.OPTION_MAGIC_CREATE, true);
                         account.setOption(Account.OPTION_FIXED_USERNAME, fixedUsername);
                         if (this.preAuth != null) {
-                            account.setKey(Account.PRE_AUTH_REGISTRATION_TOKEN, this.preAuth);
+                            account.setKey(Account.KEY_PRE_AUTH_REGISTRATION_TOKEN, this.preAuth);
                         }
                         xmppConnectionService.createAccount(account);
                     }

src/main/java/eu/siacs/conversations/crypto/sasl/ChannelBinding.java 🔗

@@ -3,6 +3,7 @@ package eu.siacs.conversations.crypto.sasl;
 import android.util.Log;
 
 import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
 
 import java.util.Collection;
 
@@ -27,6 +28,17 @@ public enum ChannelBinding {
         }
     }
 
+    public static ChannelBinding get(final String name) {
+        if (Strings.isNullOrEmpty(name)) {
+            return NONE;
+        }
+        try {
+            return valueOf(name);
+        } catch (final IllegalArgumentException e) {
+            return NONE;
+        }
+    }
+
     public static ChannelBinding best(final Collection<ChannelBinding> bindings) {
         if (bindings.contains(TLS_EXPORTER)) {
             return TLS_EXPORTER;

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -3,6 +3,7 @@ package eu.siacs.conversations.crypto.sasl;
 import com.google.common.base.Strings;
 
 import java.util.Collection;
+import java.util.Collections;
 
 import javax.net.ssl.SSLSocket;
 
@@ -129,5 +130,10 @@ public abstract class SaslMechanism {
                 return null;
             }
         }
+
+        public SaslMechanism of(final String mechanism, final ChannelBinding channelBinding) {
+            return of(Collections.singleton(mechanism), Collections.singleton(channelBinding));
+        }
+
     }
 }

src/main/java/eu/siacs/conversations/crypto/sasl/ScramPlusMechanism.java 🔗

@@ -16,7 +16,7 @@ import javax.net.ssl.SSLSocket;
 
 import eu.siacs.conversations.entities.Account;
 
-abstract class ScramPlusMechanism extends ScramMechanism {
+public abstract class ScramPlusMechanism extends ScramMechanism {
 
     private static final String EXPORTER_LABEL = "EXPORTER-Channel-Binding";
 
@@ -51,8 +51,7 @@ abstract class ScramPlusMechanism extends ScramMechanism {
             }
             return unique;
         } else if (this.channelBinding == ChannelBinding.TLS_SERVER_END_POINT) {
-            final byte[] endPoint = getServerEndPointChannelBinding(sslSocket.getSession());
-            return endPoint;
+            return getServerEndPointChannelBinding(sslSocket.getSession());
         } else {
             throw new AuthenticationException(
                     String.format("%s is not a valid channel binding", channelBinding));
@@ -103,4 +102,8 @@ abstract class ScramPlusMechanism extends ScramMechanism {
         messageDigest.update(encodedCertificate);
         return messageDigest.digest();
     }
+
+    public ChannelBinding getChannelBinding() {
+        return this.channelBinding;
+    }
 }

src/main/java/eu/siacs/conversations/entities/Account.java 🔗

@@ -25,6 +25,9 @@ import eu.siacs.conversations.R;
 import eu.siacs.conversations.crypto.PgpDecryptionService;
 import eu.siacs.conversations.crypto.axolotl.AxolotlService;
 import eu.siacs.conversations.crypto.axolotl.XmppAxolotlSession;
+import eu.siacs.conversations.crypto.sasl.ChannelBinding;
+import eu.siacs.conversations.crypto.sasl.SaslMechanism;
+import eu.siacs.conversations.crypto.sasl.ScramPlusMechanism;
 import eu.siacs.conversations.services.AvatarService;
 import eu.siacs.conversations.services.XmppConnectionService;
 import eu.siacs.conversations.utils.UIHelper;
@@ -50,9 +53,9 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
     public static final String STATUS = "status";
     public static final String STATUS_MESSAGE = "status_message";
     public static final String RESOURCE = "resource";
+    public static final String PINNED_MECHANISM = "pinned_mechanism";
+    public static final String PINNED_CHANNEL_BINDING = "pinned_channel_binding";
 
-    public static final String PINNED_MECHANISM_KEY = "pinned_mechanism";
-    public static final String PRE_AUTH_REGISTRATION_TOKEN = "pre_auth_registration";
 
     public static final int OPTION_USETLS = 0;
     public static final int OPTION_DISABLED = 1;
@@ -64,8 +67,13 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
     public static final int OPTION_HTTP_UPLOAD_AVAILABLE = 7;
     public static final int OPTION_UNVERIFIED = 8;
     public static final int OPTION_FIXED_USERNAME = 9;
+
     private static final String KEY_PGP_SIGNATURE = "pgp_signature";
     private static final String KEY_PGP_ID = "pgp_id";
+    private static final String KEY_PINNED_MECHANISM = "pinned_mechanism";
+    public static final String KEY_PRE_AUTH_REGISTRATION_TOKEN = "pre_auth_registration";
+
+
     protected final JSONObject keys;
     private final Roster roster = new Roster(this);
     private final Collection<Jid> blocklist = new CopyOnWriteArraySet<>();
@@ -90,18 +98,20 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
     private XmppConnection xmppConnection = null;
     private long mEndGracePeriod = 0L;
     private final Map<Jid, Bookmark> bookmarks = new HashMap<>();
-    private Presence.Status presenceStatus = Presence.Status.ONLINE;
-    private String presenceStatusMessage = null;
+    private Presence.Status presenceStatus;
+    private String presenceStatusMessage;
+    private String pinnedMechanism;
+    private String pinnedChannelBinding;
 
     public Account(final Jid jid, final String password) {
         this(java.util.UUID.randomUUID().toString(), jid,
-                password, 0, null, "", null, null, null, 5222, Presence.Status.ONLINE, null);
+                password, 0, null, "", null, null, null, 5222, Presence.Status.ONLINE, null, null, null);
     }
 
     private Account(final String uuid, final Jid jid,
                     final String password, final int options, final String rosterVersion, final String keys,
                     final String avatar, String displayName, String hostname, int port,
-                    final Presence.Status status, String statusMessage) {
+                    final Presence.Status status, String statusMessage, final String pinnedMechanism, final String pinnedChannelBinding) {
         this.uuid = uuid;
         this.jid = jid;
         this.password = password;
@@ -120,19 +130,21 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
         this.port = port;
         this.presenceStatus = status;
         this.presenceStatusMessage = statusMessage;
+        this.pinnedMechanism = pinnedMechanism;
+        this.pinnedChannelBinding = pinnedChannelBinding;
     }
 
     public static Account fromCursor(final Cursor cursor) {
         final Jid jid;
         try {
-            String resource = cursor.getString(cursor.getColumnIndexOrThrow(RESOURCE));
+            final String resource = cursor.getString(cursor.getColumnIndexOrThrow(RESOURCE));
             jid = Jid.of(
                     cursor.getString(cursor.getColumnIndexOrThrow(USERNAME)),
                     cursor.getString(cursor.getColumnIndexOrThrow(SERVER)),
                     resource == null || resource.trim().isEmpty() ? null : resource);
-        } catch (final IllegalArgumentException ignored) {
+        } catch (final IllegalArgumentException e) {
             Log.d(Config.LOGTAG, cursor.getString(cursor.getColumnIndexOrThrow(USERNAME)) + "@" + cursor.getString(cursor.getColumnIndexOrThrow(SERVER)));
-            throw new AssertionError(ignored);
+            throw new AssertionError(e);
         }
         return new Account(cursor.getString(cursor.getColumnIndexOrThrow(UUID)),
                 jid,
@@ -145,7 +157,9 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
                 cursor.getString(cursor.getColumnIndexOrThrow(HOSTNAME)),
                 cursor.getInt(cursor.getColumnIndexOrThrow(PORT)),
                 Presence.Status.fromShowString(cursor.getString(cursor.getColumnIndexOrThrow(STATUS))),
-                cursor.getString(cursor.getColumnIndexOrThrow(STATUS_MESSAGE)));
+                cursor.getString(cursor.getColumnIndexOrThrow(STATUS_MESSAGE)),
+                cursor.getString(cursor.getColumnIndexOrThrow(PINNED_MECHANISM)),
+                cursor.getString(cursor.getColumnIndexOrThrow(PINNED_CHANNEL_BINDING)));
     }
 
     public boolean httpUploadAvailable(long size) {
@@ -289,6 +303,38 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
         }
     }
 
+    public void setPinnedMechanism(final SaslMechanism mechanism) {
+        this.pinnedMechanism = mechanism.getMechanism();
+        if (mechanism instanceof ScramPlusMechanism) {
+            this.pinnedChannelBinding = ((ScramPlusMechanism) mechanism).getChannelBinding().toString();
+        }
+    }
+
+    public void resetPinnedMechanism() {
+        this.pinnedMechanism = null;
+        this.pinnedChannelBinding = null;
+        setKey(Account.KEY_PINNED_MECHANISM, String.valueOf(-1));
+    }
+
+    public int getPinnedMechanismPriority() {
+        final int fallback = getKeyAsInt(KEY_PINNED_MECHANISM, -1);
+        if (Strings.isNullOrEmpty(this.pinnedMechanism)) {
+            return fallback;
+        }
+        final SaslMechanism saslMechanism = getPinnedMechanism();
+        if (saslMechanism == null) {
+            return fallback;
+        } else {
+            return saslMechanism.getPriority();
+        }
+    }
+
+    public SaslMechanism getPinnedMechanism() {
+        final String mechanism = Strings.nullToEmpty(this.pinnedMechanism);
+        final ChannelBinding channelBinding = ChannelBinding.get(this.pinnedChannelBinding);
+        return new SaslMechanism.Factory(this).of(mechanism, channelBinding);
+    }
+
     public State getTrueStatus() {
         return this.status;
     }
@@ -361,8 +407,8 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
         }
     }
 
-    public boolean setPrivateKeyAlias(String alias) {
-        return setKey("private_key_alias", alias);
+    public void setPrivateKeyAlias(final String alias) {
+        setKey("private_key_alias", alias);
     }
 
     public String getPrivateKeyAlias() {
@@ -388,6 +434,8 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
         values.put(STATUS, presenceStatus.toShowString());
         values.put(STATUS_MESSAGE, presenceStatusMessage);
         values.put(RESOURCE, jid.getResource());
+        values.put(PINNED_MECHANISM, pinnedMechanism);
+        values.put(PINNED_CHANNEL_BINDING, pinnedChannelBinding);
         return values;
     }
 

src/main/java/eu/siacs/conversations/persistance/DatabaseBackend.java 🔗

@@ -64,7 +64,7 @@ import eu.siacs.conversations.xmpp.mam.MamReference;
 public class DatabaseBackend extends SQLiteOpenHelper {
 
     private static final String DATABASE_NAME = "history";
-    private static final int DATABASE_VERSION = 49;
+    private static final int DATABASE_VERSION = 50;
 
     private static boolean requiresMessageIndexRebuild = false;
     private static DatabaseBackend instance = null;
@@ -230,6 +230,8 @@ public class DatabaseBackend extends SQLiteOpenHelper {
                 + Account.KEYS + " TEXT, "
                 + Account.HOSTNAME + " TEXT, "
                 + Account.RESOURCE + " TEXT,"
+                + Account.PINNED_MECHANISM + " TEXT,"
+                + Account.PINNED_CHANNEL_BINDING + " TEXT,"
                 + Account.PORT + " NUMBER DEFAULT 5222)");
         db.execSQL("create table " + Conversation.TABLENAME + " ("
                 + Conversation.UUID + " TEXT PRIMARY KEY, " + Conversation.NAME
@@ -589,6 +591,11 @@ public class DatabaseBackend extends SQLiteOpenHelper {
             db.endTransaction();
             requiresMessageIndexRebuild = true;
         }
+        if (oldVersion < 50 && newVersion >= 50) {
+            db.execSQL("ALTER TABLE " + Account.TABLENAME + " ADD COLUMN " + Account.PINNED_MECHANISM + " TEXT");
+            db.execSQL("ALTER TABLE " + Account.TABLENAME + " ADD COLUMN " + Account.PINNED_CHANNEL_BINDING + " TEXT");
+
+        }
     }
 
     private void canonicalizeJids(SQLiteDatabase db) {
@@ -938,20 +945,19 @@ public class DatabaseBackend extends SQLiteOpenHelper {
                 contactJid.asBareJid().toString() + "/%",
                 contactJid.asBareJid().toString()
         };
-        Cursor cursor = db.query(Conversation.TABLENAME, null,
+        try(final Cursor cursor = db.query(Conversation.TABLENAME, null,
                 Conversation.ACCOUNT + "=? AND (" + Conversation.CONTACTJID
-                        + " like ? OR " + Conversation.CONTACTJID + "=?)", selectionArgs, null, null, null);
-        if (cursor.getCount() == 0) {
-            cursor.close();
-            return null;
-        }
-        cursor.moveToFirst();
-        Conversation conversation = Conversation.fromCursor(cursor);
-        cursor.close();
-        if (conversation.getJid() instanceof InvalidJid) {
-            return null;
+                        + " like ? OR " + Conversation.CONTACTJID + "=?)", selectionArgs, null, null, null)) {
+            if (cursor.getCount() == 0) {
+                return null;
+            }
+            cursor.moveToFirst();
+            final Conversation conversation = Conversation.fromCursor(cursor);
+            if (conversation.getJid() instanceof InvalidJid) {
+                return null;
+            }
+            return conversation;
         }
-        return conversation;
     }
 
     public void updateConversation(final Conversation conversation) {
@@ -1024,14 +1030,14 @@ public class DatabaseBackend extends SQLiteOpenHelper {
     }
 
     public void readRoster(Roster roster) {
-        SQLiteDatabase db = this.getReadableDatabase();
-        Cursor cursor;
-        String[] args = {roster.getAccount().getUuid()};
-        cursor = db.query(Contact.TABLENAME, null, Contact.ACCOUNT + "=?", args, null, null, null);
-        while (cursor.moveToNext()) {
-            roster.initContact(Contact.fromCursor(cursor));
+        final SQLiteDatabase db = this.getReadableDatabase();
+        final String[] args = {roster.getAccount().getUuid()};
+        try (final Cursor cursor =
+                db.query(Contact.TABLENAME, null, Contact.ACCOUNT + "=?", args, null, null, null)) {
+            while (cursor.moveToNext()) {
+                roster.initContact(Contact.fromCursor(cursor));
+            }
         }
-        cursor.close();
     }
 
     public void writeRoster(final Roster roster) {

src/main/java/eu/siacs/conversations/ui/EditAccountActivity.java 🔗

@@ -181,7 +181,7 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
             }
 
             if (inNeedOfSaslAccept()) {
-                mAccount.setKey(Account.PINNED_MECHANISM_KEY, String.valueOf(-1));
+                mAccount.resetPinnedMechanism();
                 if (!xmppConnectionService.updateAccount(mAccount)) {
                     Toast.makeText(EditAccountActivity.this, R.string.unable_to_update_account, Toast.LENGTH_SHORT).show();
                 }
@@ -421,7 +421,7 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
             } else {
                 preset = jid.getDomain();
             }
-            final Intent intent = SignupUtils.getTokenRegistrationIntent(this, preset, mAccount.getKey(Account.PRE_AUTH_REGISTRATION_TOKEN));
+            final Intent intent = SignupUtils.getTokenRegistrationIntent(this, preset, mAccount.getKey(Account.KEY_PRE_AUTH_REGISTRATION_TOKEN));
             StartConversationActivity.addInviteUri(intent, getIntent());
             startActivity(intent);
             return;
@@ -892,7 +892,7 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
     }
 
     private boolean inNeedOfSaslAccept() {
-        return mAccount != null && mAccount.getLastErrorStatus() == Account.State.DOWNGRADE_ATTACK && mAccount.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1) >= 0 && !accountInfoEdited();
+        return mAccount != null && mAccount.getLastErrorStatus() == Account.State.DOWNGRADE_ATTACK && mAccount.getPinnedMechanismPriority() >= 0 && !accountInfoEdited();
     }
 
     private void shareBarcode() {

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -692,8 +692,7 @@ public class XmppConnection implements Runnable {
         Log.d(
                 Config.LOGTAG,
                 account.getJid().asBareJid().toString() + ": logged in (using " + version + ")");
-        // TODO store mechanism name
-        account.setKey(Account.PINNED_MECHANISM_KEY, String.valueOf(saslMechanism.getPriority()));
+        account.setPinnedMechanism(saslMechanism);
         if (version == SaslMechanism.Version.SASL_2) {
             final String authorizationIdentifier =
                     success.findChildContent("authorization-identifier");
@@ -1264,7 +1263,7 @@ public class XmppConnection implements Runnable {
                             + mechanisms);
             throw new StateChangingException(Account.State.INCOMPATIBLE_SERVER);
         }
-        final int pinnedMechanism = account.getKeyAsInt(Account.PINNED_MECHANISM_KEY, -1);
+        final int pinnedMechanism = account.getPinnedMechanismPriority();
         if (pinnedMechanism > saslMechanism.getPriority()) {
             Log.e(
                     Config.LOGTAG,
@@ -1345,7 +1344,7 @@ public class XmppConnection implements Runnable {
     }
 
     private void register() {
-        final String preAuth = account.getKey(Account.PRE_AUTH_REGISTRATION_TOKEN);
+        final String preAuth = account.getKey(Account.KEY_PRE_AUTH_REGISTRATION_TOKEN);
         if (preAuth != null && features.invite()) {
             final IqPacket preAuthRequest = new IqPacket(IqPacket.TYPE.SET);
             preAuthRequest.addChild("preauth", Namespace.PARS).setAttribute("token", preAuth);