1package eu.siacs.conversations.crypto.sasl;
2
import android.annotation.TargetApi;
import android.os.Build;
import android.util.Base64;

import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;

import java.nio.charset.Charset;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.concurrent.ExecutionException;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import eu.siacs.conversations.xml.TagWriter;
23
24@TargetApi(Build.VERSION_CODES.HONEYCOMB_MR1)
25abstract class ScramMechanism extends SaslMechanism {
26 // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to indicate support and/or usage.
27 private final static String GS2_HEADER = "n,,";
28 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
29 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
30
31 protected abstract HMac getHMAC();
32
33 protected abstract Digest getDigest();
34
35 private static final Cache<CacheKey, KeyPair> CACHE = CacheBuilder.newBuilder().maximumSize(10).build();
36
37 private static class CacheKey {
38 final String algorithm;
39 final String password;
40 final String salt;
41 final int iterations;
42
43 private CacheKey(String algorithm, String password, String salt, int iterations) {
44 this.algorithm = algorithm;
45 this.password = password;
46 this.salt = salt;
47 this.iterations = iterations;
48 }
49
50 @Override
51 public boolean equals(Object o) {
52 if (this == o) return true;
53 if (o == null || getClass() != o.getClass()) return false;
54 CacheKey cacheKey = (CacheKey) o;
55 return iterations == cacheKey.iterations &&
56 Objects.equal(algorithm, cacheKey.algorithm) &&
57 Objects.equal(password, cacheKey.password) &&
58 Objects.equal(salt, cacheKey.salt);
59 }
60
61 @Override
62 public int hashCode() {
63 return Objects.hashCode(algorithm, password, salt, iterations);
64 }
65 }
66
67 private KeyPair getKeyPair(final String password, final String salt, final int iterations) throws ExecutionException {
68 return CACHE.get(new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations), () -> {
69 final byte[] saltedPassword, serverKey, clientKey;
70 saltedPassword = hi(password.getBytes(), Base64.decode(salt, Base64.DEFAULT), iterations);
71 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
72 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
73 return new KeyPair(clientKey, serverKey);
74 });
75 }
76
77 private final String clientNonce;
78 protected State state = State.INITIAL;
79 private String clientFirstMessageBare;
80 private byte[] serverSignature = null;
81
82 ScramMechanism(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
83 super(tagWriter, account, rng);
84
85 // This nonce should be different for each authentication attempt.
86 clientNonce = CryptoHelper.random(100, rng);
87 clientFirstMessageBare = "";
88 }
89
90 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
91 final HMac hMac = getHMAC();
92 hMac.init(new KeyParameter(key));
93 hMac.update(input, 0, input.length);
94 final byte[] out = new byte[hMac.getMacSize()];
95 hMac.doFinal(out, 0);
96 return out;
97 }
98
99 public byte[] digest(byte[] bytes) {
100 final Digest digest = getDigest();
101 digest.reset();
102 digest.update(bytes, 0, bytes.length);
103 final byte[] out = new byte[digest.getDigestSize()];
104 digest.doFinal(out, 0);
105 return out;
106 }
107
108 /*
109 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
110 * pseudorandom function (PRF) and with dkLen == output length of
111 * HMAC() == output length of H().
112 */
113 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
114 throws InvalidKeyException {
115 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
116 byte[] out = u.clone();
117 for (int i = 1; i < iterations; i++) {
118 u = hmac(key, u);
119 for (int j = 0; j < u.length; j++) {
120 out[j] ^= u[j];
121 }
122 }
123 return out;
124 }
125
126 @Override
127 public String getClientFirstMessage() {
128 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
129 clientFirstMessageBare = "n=" + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())) +
130 ",r=" + this.clientNonce;
131 state = State.AUTH_TEXT_SENT;
132 }
133 return Base64.encodeToString(
134 (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
135 Base64.NO_WRAP);
136 }
137
138 @Override
139 public String getResponse(final String challenge) throws AuthenticationException {
140 switch (state) {
141 case AUTH_TEXT_SENT:
142 if (challenge == null) {
143 throw new AuthenticationException("challenge can not be null");
144 }
145 byte[] serverFirstMessage;
146 try {
147 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
148 } catch (IllegalArgumentException e) {
149 throw new AuthenticationException("Unable to decode server challenge", e);
150 }
151 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
152 String nonce = "";
153 int iterationCount = -1;
154 String salt = "";
155 for (final String token : tokenizer) {
156 if (token.charAt(1) == '=') {
157 switch (token.charAt(0)) {
158 case 'i':
159 try {
160 iterationCount = Integer.parseInt(token.substring(2));
161 } catch (final NumberFormatException e) {
162 throw new AuthenticationException(e);
163 }
164 break;
165 case 's':
166 salt = token.substring(2);
167 break;
168 case 'r':
169 nonce = token.substring(2);
170 break;
171 case 'm':
172 /*
173 * RFC 5802:
174 * m: This attribute is reserved for future extensibility. In this
175 * version of SCRAM, its presence in a client or a server message
176 * MUST cause authentication failure when the attribute is parsed by
177 * the other end.
178 */
179 throw new AuthenticationException("Server sent reserved token: `m'");
180 }
181 }
182 }
183
184 if (iterationCount < 0) {
185 throw new AuthenticationException("Server did not send iteration count");
186 }
187 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
188 throw new AuthenticationException("Server nonce does not contain client nonce: " + nonce);
189 }
190 if (salt.isEmpty()) {
191 throw new AuthenticationException("Server sent empty salt");
192 }
193
194 final String clientFinalMessageWithoutProof = "c=" + Base64.encodeToString(
195 GS2_HEADER.getBytes(), Base64.NO_WRAP) + ",r=" + nonce;
196 final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
197 + clientFinalMessageWithoutProof).getBytes();
198
199 final KeyPair keys;
200 try {
201 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
202 } catch (ExecutionException e) {
203 throw new AuthenticationException("Invalid keys generated");
204 }
205 final byte[] clientSignature;
206 try {
207 serverSignature = hmac(keys.serverKey, authMessage);
208 final byte[] storedKey = digest(keys.clientKey);
209
210 clientSignature = hmac(storedKey, authMessage);
211
212 } catch (final InvalidKeyException e) {
213 throw new AuthenticationException(e);
214 }
215
216 final byte[] clientProof = new byte[keys.clientKey.length];
217
218 if (clientSignature.length < keys.clientKey.length) {
219 throw new AuthenticationException("client signature was shorter than clientKey");
220 }
221
222 for (int i = 0; i < clientProof.length; i++) {
223 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
224 }
225
226
227 final String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" +
228 Base64.encodeToString(clientProof, Base64.NO_WRAP);
229 state = State.RESPONSE_SENT;
230 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
231 case RESPONSE_SENT:
232 try {
233 final String clientCalculatedServerFinalMessage = "v=" +
234 Base64.encodeToString(serverSignature, Base64.NO_WRAP);
235 if (!clientCalculatedServerFinalMessage.equals(new String(Base64.decode(challenge, Base64.DEFAULT)))) {
236 throw new Exception();
237 }
238 state = State.VALID_SERVER_RESPONSE;
239 return "";
240 } catch (Exception e) {
241 throw new AuthenticationException("Server final message does not match calculated final message");
242 }
243 default:
244 throw new InvalidStateException(state);
245 }
246 }
247
248 private static class KeyPair {
249 final byte[] clientKey;
250 final byte[] serverKey;
251
252 KeyPair(final byte[] clientKey, final byte[] serverKey) {
253 this.clientKey = clientKey;
254 this.serverKey = serverKey;
255 }
256 }
257}