Cache SCRAM-SHA-1 keys for current session

Sam Whited created

Change summary

src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java | 57 ++
src/main/java/eu/siacs/conversations/utils/CryptoHelper.java    |  4 
2 files changed, 51 insertions(+), 10 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java 🔗

@@ -1,6 +1,7 @@
 package eu.siacs.conversations.crypto.sasl;
 
 import android.util.Base64;
+import android.util.LruCache;
 
 import org.bouncycastle.crypto.Digest;
 import org.bouncycastle.crypto.digests.SHA1Digest;
@@ -28,9 +29,40 @@ public class ScramSha1 extends SaslMechanism {
 	private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 	private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 
+	public static class KeyPair {
+		final public byte[] clientKey;
+		final public byte[] serverKey;
+
+		public KeyPair(final byte[] clientKey, final byte[] serverKey) {
+			this.clientKey = clientKey;
+			this.serverKey = serverKey;
+		}
+	}
+
+	private static final LruCache<String, KeyPair> CACHE;
+
 	static {
 		DIGEST = new SHA1Digest();
 		HMAC = new HMac(new SHA1Digest());
+		CACHE = new LruCache<String, KeyPair>(10) {
+			protected KeyPair create(final String k) {
+				// Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations".
+				// Changing any of these values forces a cache miss. `CryptoHelper.bytesToHex()'
+				// is applied to prevent commas in the strings breaking things.
+				final String[] kparts = k.split(",", 4);
+				try {
+					final byte[] saltedPassword, serverKey, clientKey;
+					saltedPassword = hi(CryptoHelper.saslPrep(CryptoHelper.hexToString(kparts[1])).getBytes(),
+							Base64.decode(CryptoHelper.hexToString(kparts[2]), Base64.DEFAULT), Integer.valueOf(kparts[3]));
+					serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
+					clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
+
+					return new KeyPair(clientKey, serverKey);
+				} catch (final InvalidKeyException | NumberFormatException e) {
+					return null;
+				}
+			}
+		};
 	}
 
 	private State state = State.INITIAL;
@@ -118,15 +150,20 @@ public class ScramSha1 extends SaslMechanism {
 				final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
 						+ clientFinalMessageWithoutProof).getBytes();
 
-				// TODO: In future, cache the clientKey and serverKey and re-use them on re-auth.
-				final byte[] saltedPassword, clientSignature, serverKey, clientKey;
+				// Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations".
+				final KeyPair keys = CACHE.get(
+						CryptoHelper.bytesToHex(account.getJid().toBareJid().toString().getBytes()) + ","
+						+ CryptoHelper.bytesToHex(account.getPassword().getBytes()) + ","
+						+ CryptoHelper.bytesToHex(salt.getBytes()) + ","
+						+ String.valueOf(iterationCount)
+						);
+				if (keys == null) {
+					throw new AuthenticationException("Invalid keys generated");
+				}
+				final byte[] clientSignature;
 				try {
-					saltedPassword = hi(CryptoHelper.saslPrep(account.getPassword()).getBytes(),
-							Base64.decode(salt, Base64.DEFAULT), iterationCount);
-					serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
-					serverSignature = hmac(serverKey, authMessage);
-					clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
-					final byte[] storedKey = digest(clientKey);
+					serverSignature = hmac(keys.serverKey, authMessage);
+					final byte[] storedKey = digest(keys.clientKey);
 
 					clientSignature = hmac(storedKey, authMessage);
 
@@ -134,10 +171,10 @@ public class ScramSha1 extends SaslMechanism {
 					throw new AuthenticationException(e);
 				}
 
-				final byte[] clientProof = new byte[clientKey.length];
+				final byte[] clientProof = new byte[keys.clientKey.length];
 
 				for (int i = 0; i < clientProof.length; i++) {
-					clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
+					clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
 				}
 
 

src/main/java/eu/siacs/conversations/utils/CryptoHelper.java 🔗

@@ -30,6 +30,10 @@ public class CryptoHelper {
 		return array;
 	}
 
+	public static String hexToString(final String hexString) {
+		return new String(hexToBytes(hexString));
+	}
+
 	public static byte[] concatenateByteArrays(byte[] a, byte[] b) {
 		byte[] result = new byte[a.length + b.length];
 		System.arraycopy(a, 0, result, 0, a.length);