ScramSha1.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.util.Base64;
  4import android.util.LruCache;
  5
  6import org.bouncycastle.crypto.Digest;
  7import org.bouncycastle.crypto.digests.SHA1Digest;
  8import org.bouncycastle.crypto.macs.HMac;
  9import org.bouncycastle.crypto.params.KeyParameter;
 10
 11import java.math.BigInteger;
 12import java.nio.charset.Charset;
 13import java.security.InvalidKeyException;
 14import java.security.SecureRandom;
 15
 16import eu.siacs.conversations.entities.Account;
 17import eu.siacs.conversations.utils.CryptoHelper;
 18import eu.siacs.conversations.xml.TagWriter;
 19
 20public class ScramSha1 extends SaslMechanism {
 21	// TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to indicate support and/or usage.
 22	final private static String GS2_HEADER = "n,,";
 23	private String clientFirstMessageBare;
 24	private byte[] serverFirstMessage;
 25	final private String clientNonce;
 26	private byte[] serverSignature = null;
 27	private static HMac HMAC;
 28	private static Digest DIGEST;
 29	private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 30	private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 31
 32	public static class KeyPair {
 33		final public byte[] clientKey;
 34		final public byte[] serverKey;
 35
 36		public KeyPair(final byte[] clientKey, final byte[] serverKey) {
 37			this.clientKey = clientKey;
 38			this.serverKey = serverKey;
 39		}
 40	}
 41
 42	private static final LruCache<String, KeyPair> CACHE;
 43
 44	static {
 45		DIGEST = new SHA1Digest();
 46		HMAC = new HMac(new SHA1Digest());
 47		CACHE = new LruCache<String, KeyPair>(10) {
 48			protected KeyPair create(final String k) {
 49				// Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations".
 50				// Changing any of these values forces a cache miss. `CryptoHelper.bytesToHex()'
 51				// is applied to prevent commas in the strings breaking things.
 52				final String[] kparts = k.split(",", 4);
 53				try {
 54					final byte[] saltedPassword, serverKey, clientKey;
 55					saltedPassword = hi(CryptoHelper.hexToString(kparts[1]).getBytes(),
 56							Base64.decode(CryptoHelper.hexToString(kparts[2]), Base64.DEFAULT), Integer.valueOf(kparts[3]));
 57					serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 58					clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 59
 60					return new KeyPair(clientKey, serverKey);
 61				} catch (final InvalidKeyException | NumberFormatException e) {
 62					return null;
 63				}
 64			}
 65		};
 66	}
 67
 68	private State state = State.INITIAL;
 69
 70	public ScramSha1(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
 71		super(tagWriter, account, rng);
 72
 73		// This nonce should be different for each authentication attempt.
 74		clientNonce = new BigInteger(100, this.rng).toString(32);
 75		clientFirstMessageBare = "";
 76	}
 77
 78	@Override
 79	public int getPriority() {
 80		return 20;
 81	}
 82
 83	@Override
 84	public String getMechanism() {
 85		return "SCRAM-SHA-1";
 86	}
 87
 88	@Override
 89	public String getClientFirstMessage() {
 90		if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
 91			clientFirstMessageBare = "n=" + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())) +
 92				",r=" + this.clientNonce;
 93			state = State.AUTH_TEXT_SENT;
 94		}
 95		return Base64.encodeToString(
 96				(GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
 97				Base64.NO_WRAP);
 98	}
 99
100	@Override
101	public String getResponse(final String challenge) throws AuthenticationException {
102		switch (state) {
103			case AUTH_TEXT_SENT:
104				if (challenge == null) {
105					throw new AuthenticationException("challenge can not be null");
106				}
107				serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
108				final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
109				String nonce = "";
110				int iterationCount = -1;
111				String salt = "";
112				for (final String token : tokenizer) {
113					if (token.charAt(1) == '=') {
114						switch (token.charAt(0)) {
115							case 'i':
116								try {
117									iterationCount = Integer.parseInt(token.substring(2));
118								} catch (final NumberFormatException e) {
119									throw new AuthenticationException(e);
120								}
121								break;
122							case 's':
123								salt = token.substring(2);
124								break;
125							case 'r':
126								nonce = token.substring(2);
127								break;
128							case 'm':
129								/*
130								 * RFC 5802:
131								 * m: This attribute is reserved for future extensibility.  In this
132								 * version of SCRAM, its presence in a client or a server message
133								 * MUST cause authentication failure when the attribute is parsed by
134								 * the other end.
135								 */
136								throw new AuthenticationException("Server sent reserved token: `m'");
137						}
138					}
139				}
140
141				if (iterationCount < 0) {
142					throw new AuthenticationException("Server did not send iteration count");
143				}
144				if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
145					throw new AuthenticationException("Server nonce does not contain client nonce: " + nonce);
146				}
147				if (salt.isEmpty()) {
148					throw new AuthenticationException("Server sent empty salt");
149				}
150
151				final String clientFinalMessageWithoutProof = "c=" + Base64.encodeToString(
152						GS2_HEADER.getBytes(), Base64.NO_WRAP) + ",r=" + nonce;
153				final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
154						+ clientFinalMessageWithoutProof).getBytes();
155
156				// Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations".
157				final KeyPair keys = CACHE.get(
158						CryptoHelper.bytesToHex(account.getJid().toBareJid().toString().getBytes()) + ","
159						+ CryptoHelper.bytesToHex(account.getPassword().getBytes()) + ","
160						+ CryptoHelper.bytesToHex(salt.getBytes()) + ","
161						+ String.valueOf(iterationCount)
162						);
163				if (keys == null) {
164					throw new AuthenticationException("Invalid keys generated");
165				}
166				final byte[] clientSignature;
167				try {
168					serverSignature = hmac(keys.serverKey, authMessage);
169					final byte[] storedKey = digest(keys.clientKey);
170
171					clientSignature = hmac(storedKey, authMessage);
172
173				} catch (final InvalidKeyException e) {
174					throw new AuthenticationException(e);
175				}
176
177				final byte[] clientProof = new byte[keys.clientKey.length];
178
179				for (int i = 0; i < clientProof.length; i++) {
180					clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
181				}
182
183
184				final String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" +
185					Base64.encodeToString(clientProof, Base64.NO_WRAP);
186				state = State.RESPONSE_SENT;
187				return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
188			case RESPONSE_SENT:
189				final String clientCalculatedServerFinalMessage = "v=" +
190					Base64.encodeToString(serverSignature, Base64.NO_WRAP);
191				if (challenge == null || !clientCalculatedServerFinalMessage.equals(new String(Base64.decode(challenge, Base64.DEFAULT)))) {
192					throw new AuthenticationException("Server final message does not match calculated final message");
193				}
194				state = State.VALID_SERVER_RESPONSE;
195				return "";
196			default:
197				throw new InvalidStateException(state);
198		}
199	}
200
201	public static synchronized byte[] hmac(final byte[] key, final byte[] input)
202		throws InvalidKeyException {
203		HMAC.init(new KeyParameter(key));
204		HMAC.update(input, 0, input.length);
205		final byte[] out = new byte[HMAC.getMacSize()];
206		HMAC.doFinal(out, 0);
207		return out;
208	}
209
210	public static synchronized byte[] digest(byte[] bytes) {
211		DIGEST.reset();
212		DIGEST.update(bytes, 0, bytes.length);
213		final byte[] out = new byte[DIGEST.getDigestSize()];
214		DIGEST.doFinal(out, 0);
215		return out;
216	}
217
218	/*
219	 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
220	 * pseudorandom function (PRF) and with dkLen == output length of
221	 * HMAC() == output length of H().
222	 */
223	private static synchronized byte[] hi(final byte[] key, final byte[] salt, final int iterations)
224		throws InvalidKeyException {
225		byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
226		byte[] out = u.clone();
227		for (int i = 1; i < iterations; i++) {
228			u = hmac(key, u);
229			for (int j = 0; j < u.length; j++) {
230				out[j] ^= u[j];
231			}
232		}
233		return out;
234	}
235}