1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;

import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.concurrent.ExecutionException;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import eu.siacs.conversations.xml.TagWriter;
21
22abstract class ScramMechanism extends SaslMechanism {
23 // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to indicate support and/or usage.
24 private final static String GS2_HEADER = "n,,";
25 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
26 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
27
28 protected abstract HMac getHMAC();
29
30 protected abstract Digest getDigest();
31
32 private static final Cache<CacheKey, KeyPair> CACHE = CacheBuilder.newBuilder().maximumSize(10).build();
33
34 private static class CacheKey {
35 final String algorithm;
36 final String password;
37 final String salt;
38 final int iterations;
39
40 private CacheKey(String algorithm, String password, String salt, int iterations) {
41 this.algorithm = algorithm;
42 this.password = password;
43 this.salt = salt;
44 this.iterations = iterations;
45 }
46
47 @Override
48 public boolean equals(Object o) {
49 if (this == o) return true;
50 if (o == null || getClass() != o.getClass()) return false;
51 CacheKey cacheKey = (CacheKey) o;
52 return iterations == cacheKey.iterations &&
53 Objects.equal(algorithm, cacheKey.algorithm) &&
54 Objects.equal(password, cacheKey.password) &&
55 Objects.equal(salt, cacheKey.salt);
56 }
57
58 @Override
59 public int hashCode() {
60 return Objects.hashCode(algorithm, password, salt, iterations);
61 }
62 }
63
64 private KeyPair getKeyPair(final String password, final String salt, final int iterations) throws ExecutionException {
65 return CACHE.get(new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations), () -> {
66 final byte[] saltedPassword, serverKey, clientKey;
67 saltedPassword = hi(password.getBytes(), Base64.decode(salt, Base64.DEFAULT), iterations);
68 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
69 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
70 return new KeyPair(clientKey, serverKey);
71 });
72 }
73
74 private final String clientNonce;
75 protected State state = State.INITIAL;
76 private String clientFirstMessageBare;
77 private byte[] serverSignature = null;
78
79 ScramMechanism(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
80 super(tagWriter, account, rng);
81
82 // This nonce should be different for each authentication attempt.
83 clientNonce = CryptoHelper.random(100, rng);
84 clientFirstMessageBare = "";
85 }
86
87 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
88 final HMac hMac = getHMAC();
89 hMac.init(new KeyParameter(key));
90 hMac.update(input, 0, input.length);
91 final byte[] out = new byte[hMac.getMacSize()];
92 hMac.doFinal(out, 0);
93 return out;
94 }
95
96 public byte[] digest(byte[] bytes) {
97 final Digest digest = getDigest();
98 digest.reset();
99 digest.update(bytes, 0, bytes.length);
100 final byte[] out = new byte[digest.getDigestSize()];
101 digest.doFinal(out, 0);
102 return out;
103 }
104
105 /*
106 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
107 * pseudorandom function (PRF) and with dkLen == output length of
108 * HMAC() == output length of H().
109 */
110 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
111 throws InvalidKeyException {
112 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
113 byte[] out = u.clone();
114 for (int i = 1; i < iterations; i++) {
115 u = hmac(key, u);
116 for (int j = 0; j < u.length; j++) {
117 out[j] ^= u[j];
118 }
119 }
120 return out;
121 }
122
123 @Override
124 public String getClientFirstMessage() {
125 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
126 clientFirstMessageBare = "n=" + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())) +
127 ",r=" + this.clientNonce;
128 state = State.AUTH_TEXT_SENT;
129 }
130 return Base64.encodeToString(
131 (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
132 Base64.NO_WRAP);
133 }
134
135 @Override
136 public String getResponse(final String challenge) throws AuthenticationException {
137 switch (state) {
138 case AUTH_TEXT_SENT:
139 if (challenge == null) {
140 throw new AuthenticationException("challenge can not be null");
141 }
142 byte[] serverFirstMessage;
143 try {
144 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
145 } catch (IllegalArgumentException e) {
146 throw new AuthenticationException("Unable to decode server challenge", e);
147 }
148 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
149 String nonce = "";
150 int iterationCount = -1;
151 String salt = "";
152 for (final String token : tokenizer) {
153 if (token.charAt(1) == '=') {
154 switch (token.charAt(0)) {
155 case 'i':
156 try {
157 iterationCount = Integer.parseInt(token.substring(2));
158 } catch (final NumberFormatException e) {
159 throw new AuthenticationException(e);
160 }
161 break;
162 case 's':
163 salt = token.substring(2);
164 break;
165 case 'r':
166 nonce = token.substring(2);
167 break;
168 case 'm':
169 /*
170 * RFC 5802:
171 * m: This attribute is reserved for future extensibility. In this
172 * version of SCRAM, its presence in a client or a server message
173 * MUST cause authentication failure when the attribute is parsed by
174 * the other end.
175 */
176 throw new AuthenticationException("Server sent reserved token: `m'");
177 }
178 }
179 }
180
181 if (iterationCount < 0) {
182 throw new AuthenticationException("Server did not send iteration count");
183 }
184 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
185 throw new AuthenticationException("Server nonce does not contain client nonce: " + nonce);
186 }
187 if (salt.isEmpty()) {
188 throw new AuthenticationException("Server sent empty salt");
189 }
190
191 final String clientFinalMessageWithoutProof = "c=" + Base64.encodeToString(
192 GS2_HEADER.getBytes(), Base64.NO_WRAP) + ",r=" + nonce;
193 final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
194 + clientFinalMessageWithoutProof).getBytes();
195
196 final KeyPair keys;
197 try {
198 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
199 } catch (ExecutionException e) {
200 throw new AuthenticationException("Invalid keys generated");
201 }
202 final byte[] clientSignature;
203 try {
204 serverSignature = hmac(keys.serverKey, authMessage);
205 final byte[] storedKey = digest(keys.clientKey);
206
207 clientSignature = hmac(storedKey, authMessage);
208
209 } catch (final InvalidKeyException e) {
210 throw new AuthenticationException(e);
211 }
212
213 final byte[] clientProof = new byte[keys.clientKey.length];
214
215 if (clientSignature.length < keys.clientKey.length) {
216 throw new AuthenticationException("client signature was shorter than clientKey");
217 }
218
219 for (int i = 0; i < clientProof.length; i++) {
220 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
221 }
222
223
224 final String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" +
225 Base64.encodeToString(clientProof, Base64.NO_WRAP);
226 state = State.RESPONSE_SENT;
227 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
228 case RESPONSE_SENT:
229 try {
230 final String clientCalculatedServerFinalMessage = "v=" +
231 Base64.encodeToString(serverSignature, Base64.NO_WRAP);
232 if (!clientCalculatedServerFinalMessage.equals(new String(Base64.decode(challenge, Base64.DEFAULT)))) {
233 throw new Exception();
234 }
235 state = State.VALID_SERVER_RESPONSE;
236 return "";
237 } catch (Exception e) {
238 throw new AuthenticationException("Server final message does not match calculated final message");
239 }
240 default:
241 throw new InvalidStateException(state);
242 }
243 }
244
245 private static class KeyPair {
246 final byte[] clientKey;
247 final byte[] serverKey;
248
249 KeyPair(final byte[] clientKey, final byte[] serverKey) {
250 this.clientKey = clientKey;
251 this.serverKey = serverKey;
252 }
253 }
254}