Add auth method pinning

Sam Whited created

Change summary

src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java     |  8 
src/main/java/eu/siacs/conversations/crypto/sasl/Plain.java         |  8 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java |  9 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java     |  8 
src/main/java/eu/siacs/conversations/entities/Account.java          |  2 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java       | 39 
6 files changed, 58 insertions(+), 16 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java 🔗

@@ -17,7 +17,13 @@ public class DigestMd5 extends SaslMechanism {
 		super(tagWriter, account, rng);
 	}
 
-	public static String getMechanism() {
+	@Override
+	public int getPriority() {
+		return 10;
+	}
+
+	@Override
+	public String getMechanism() {
 		return "DIGEST-MD5";
 	}
 

src/main/java/eu/siacs/conversations/crypto/sasl/Plain.java 🔗

@@ -12,7 +12,13 @@ public class Plain extends SaslMechanism {
 		super(tagWriter, account, null);
 	}
 
-	public static String getMechanism() {
+	@Override
+	public int getPriority() {
+		return 0;
+	}
+
+	@Override
+	public String getMechanism() {
 		return "PLAIN";
 	}
 

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -44,6 +44,15 @@ public abstract class SaslMechanism {
 		this.rng = rng;
 	}
 
+	/**
+	 * The priority is used to pin the authentication mechanism. If authentication fails, it MAY be retried with another
+	 * mechanism of the same priority, but MUST NOT be tried with a mechanism of lower priority (to prevent downgrade
+	 * attacks).
+	 * @return An arbitrary int representing the priority
+	 */
+	public abstract int getPriority();
+
+	public abstract String getMechanism();
 	public String getClientFirstMessage() {
 		return "";
 	}

src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java 🔗

@@ -43,7 +43,13 @@ public class ScramSha1 extends SaslMechanism {
 		clientFirstMessageBare = "";
 	}
 
-	public static String getMechanism() {
+	@Override
+	public int getPriority() {
+		return 20;
+	}
+
+	@Override
+	public String getMechanism() {
 		return "SCRAM-SHA-1";
 	}
 

src/main/java/eu/siacs/conversations/entities/Account.java 🔗

@@ -34,6 +34,8 @@ public class Account extends AbstractEntity {
 	public static final String KEYS = "keys";
 	public static final String AVATAR = "avatar";
 
+	public static final String PINNED_MECHANISM_KEY = "pinned_mechanism";
+
 	public static final int OPTION_USETLS = 0;
 	public static final int OPTION_DISABLED = 1;
 	public static final int OPTION_REGISTER = 2;

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -12,6 +12,8 @@ import android.util.Log;
 import android.util.SparseArray;
 
 import org.apache.http.conn.ssl.StrictHostnameVerifier;
+import org.json.JSONException;
+import org.json.JSONObject;
 import org.xmlpull.v1.XmlPullParserException;
 
 import java.io.IOException;
@@ -291,6 +293,8 @@ public class XmppConnection implements Runnable {
 									Log.e(Config.LOGTAG, String.valueOf(e));
 								}
 								Log.d(Config.LOGTAG, account.getJid().toBareJid().toString() + ": logged in");
+								account.setKey(Account.PINNED_MECHANISM_KEY,
+										String.valueOf(saslMechanism.getPriority()));
 								tagReader.reset();
 								sendStartStream();
 								processStream(tagReader.readTag());
@@ -629,23 +633,32 @@ public class XmppConnection implements Runnable {
 					.findChild("mechanisms"));
 			final Element auth = new Element("auth");
 			auth.setAttribute("xmlns", "urn:ietf:params:xml:ns:xmpp-sasl");
-			if (mechanisms.contains(ScramSha1.getMechanism())) {
+			if (mechanisms.contains("SCRAM-SHA-1")) {
 				saslMechanism = new ScramSha1(tagWriter, account, mXmppConnectionService.getRNG());
-				Log.d(Config.LOGTAG, "Authenticating with " + ScramSha1.getMechanism());
-				auth.setAttribute("mechanism", ScramSha1.getMechanism());
-			} else if (mechanisms.contains(DigestMd5.getMechanism())) {
-				Log.d(Config.LOGTAG, "Authenticating with " + DigestMd5.getMechanism());
+			} else if (mechanisms.contains("DIGEST-MD5")) {
 				saslMechanism = new DigestMd5(tagWriter, account, mXmppConnectionService.getRNG());
-				auth.setAttribute("mechanism", DigestMd5.getMechanism());
-			} else if (mechanisms.contains(Plain.getMechanism())) {
-				Log.d(Config.LOGTAG, "Authenticating with " + Plain.getMechanism());
+			} else if (mechanisms.contains("PLAIN")) {
 				saslMechanism = new Plain(tagWriter, account);
-				auth.setAttribute("mechanism", Plain.getMechanism());
 			}
-            if (!saslMechanism.getClientFirstMessage().isEmpty()) {
-                auth.setContent(saslMechanism.getClientFirstMessage());
-            }
-            tagWriter.writeElement(auth);
+			final JSONObject keys = account.getKeys();
+			try {
+				if (keys.has(Account.PINNED_MECHANISM_KEY) &&
+						keys.getInt(Account.PINNED_MECHANISM_KEY) > saslMechanism.getPriority() ) {
+					Log.e(Config.LOGTAG, "Auth failed. Authentication mechanism " + saslMechanism.getMechanism() +
+							" has lower priority (" + String.valueOf(saslMechanism.getPriority()) +
+							") than pinned priority (" + keys.getInt(Account.PINNED_MECHANISM_KEY) +
+							"). Possible downgrade attack?");
+					disconnect(true);
+						}
+			} catch (final JSONException e) {
+				Log.d(Config.LOGTAG, "Parse error while checking pinned auth mechanism");
+			}
+			Log.d(Config.LOGTAG, "Authenticating with " + saslMechanism.getMechanism());
+			auth.setAttribute("mechanism", saslMechanism.getMechanism());
+			if (!saslMechanism.getClientFirstMessage().isEmpty()) {
+				auth.setContent(saslMechanism.getClientFirstMessage());
+			}
+			tagWriter.writeElement(auth);
 		} else if (this.streamFeatures.hasChild("sm", "urn:xmpp:sm:"
 					+ smVersion)
 				&& streamId != null) {