1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;

import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.concurrent.ExecutionException;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
19
20abstract class ScramMechanism extends SaslMechanism {
21 // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to
22 // indicate support and/or usage.
23 private static final String GS2_HEADER = "n,,";
24 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
25 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
26 private static final Cache<CacheKey, KeyPair> CACHE =
27 CacheBuilder.newBuilder().maximumSize(10).build();
28 protected final ChannelBinding channelBinding;
29 private final String clientNonce;
30 protected State state = State.INITIAL;
31 private String clientFirstMessageBare;
32 private byte[] serverSignature = null;
33
34 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
35 super(account);
36 this.channelBinding = channelBinding;
37 // This nonce should be different for each authentication attempt.
38 this.clientNonce = CryptoHelper.random(100);
39 clientFirstMessageBare = "";
40 }
41
42 protected abstract HMac getHMAC();
43
44 protected abstract Digest getDigest();
45
46 private KeyPair getKeyPair(final String password, final String salt, final int iterations)
47 throws ExecutionException {
48 return CACHE.get(
49 new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations),
50 () -> {
51 final byte[] saltedPassword, serverKey, clientKey;
52 saltedPassword =
53 hi(
54 password.getBytes(),
55 Base64.decode(salt, Base64.DEFAULT),
56 iterations);
57 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
58 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
59 return new KeyPair(clientKey, serverKey);
60 });
61 }
62
63 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
64 final HMac hMac = getHMAC();
65 hMac.init(new KeyParameter(key));
66 hMac.update(input, 0, input.length);
67 final byte[] out = new byte[hMac.getMacSize()];
68 hMac.doFinal(out, 0);
69 return out;
70 }
71
72 public byte[] digest(byte[] bytes) {
73 final Digest digest = getDigest();
74 digest.reset();
75 digest.update(bytes, 0, bytes.length);
76 final byte[] out = new byte[digest.getDigestSize()];
77 digest.doFinal(out, 0);
78 return out;
79 }
80
81 /*
82 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
83 * pseudorandom function (PRF) and with dkLen == output length of
84 * HMAC() == output length of H().
85 */
86 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
87 throws InvalidKeyException {
88 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
89 byte[] out = u.clone();
90 for (int i = 1; i < iterations; i++) {
91 u = hmac(key, u);
92 for (int j = 0; j < u.length; j++) {
93 out[j] ^= u[j];
94 }
95 }
96 return out;
97 }
98
99 @Override
100 public String getClientFirstMessage() {
101 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
102 clientFirstMessageBare =
103 "n="
104 + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
105 + ",r="
106 + this.clientNonce;
107 state = State.AUTH_TEXT_SENT;
108 }
109 return Base64.encodeToString(
110 (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
111 Base64.NO_WRAP);
112 }
113
114 @Override
115 public String getResponse(final String challenge) throws AuthenticationException {
116 switch (state) {
117 case AUTH_TEXT_SENT:
118 if (challenge == null) {
119 throw new AuthenticationException("challenge can not be null");
120 }
121 byte[] serverFirstMessage;
122 try {
123 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
124 } catch (IllegalArgumentException e) {
125 throw new AuthenticationException("Unable to decode server challenge", e);
126 }
127 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
128 String nonce = "";
129 int iterationCount = -1;
130 String salt = "";
131 for (final String token : tokenizer) {
132 if (token.charAt(1) == '=') {
133 switch (token.charAt(0)) {
134 case 'i':
135 try {
136 iterationCount = Integer.parseInt(token.substring(2));
137 } catch (final NumberFormatException e) {
138 throw new AuthenticationException(e);
139 }
140 break;
141 case 's':
142 salt = token.substring(2);
143 break;
144 case 'r':
145 nonce = token.substring(2);
146 break;
147 case 'm':
148 /*
149 * RFC 5802:
150 * m: This attribute is reserved for future extensibility. In this
151 * version of SCRAM, its presence in a client or a server message
152 * MUST cause authentication failure when the attribute is parsed by
153 * the other end.
154 */
155 throw new AuthenticationException(
156 "Server sent reserved token: `m'");
157 }
158 }
159 }
160
161 if (iterationCount < 0) {
162 throw new AuthenticationException("Server did not send iteration count");
163 }
164 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
165 throw new AuthenticationException(
166 "Server nonce does not contain client nonce: " + nonce);
167 }
168 if (salt.isEmpty()) {
169 throw new AuthenticationException("Server sent empty salt");
170 }
171
172 final String clientFinalMessageWithoutProof =
173 "c="
174 + Base64.encodeToString(GS2_HEADER.getBytes(), Base64.NO_WRAP)
175 + ",r="
176 + nonce;
177 final byte[] authMessage =
178 (clientFirstMessageBare
179 + ','
180 + new String(serverFirstMessage)
181 + ','
182 + clientFinalMessageWithoutProof)
183 .getBytes();
184
185 final KeyPair keys;
186 try {
187 keys =
188 getKeyPair(
189 CryptoHelper.saslPrep(account.getPassword()),
190 salt,
191 iterationCount);
192 } catch (ExecutionException e) {
193 throw new AuthenticationException("Invalid keys generated");
194 }
195 final byte[] clientSignature;
196 try {
197 serverSignature = hmac(keys.serverKey, authMessage);
198 final byte[] storedKey = digest(keys.clientKey);
199
200 clientSignature = hmac(storedKey, authMessage);
201
202 } catch (final InvalidKeyException e) {
203 throw new AuthenticationException(e);
204 }
205
206 final byte[] clientProof = new byte[keys.clientKey.length];
207
208 if (clientSignature.length < keys.clientKey.length) {
209 throw new AuthenticationException(
210 "client signature was shorter than clientKey");
211 }
212
213 for (int i = 0; i < clientProof.length; i++) {
214 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
215 }
216
217 final String clientFinalMessage =
218 clientFinalMessageWithoutProof
219 + ",p="
220 + Base64.encodeToString(clientProof, Base64.NO_WRAP);
221 state = State.RESPONSE_SENT;
222 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
223 case RESPONSE_SENT:
224 try {
225 final String clientCalculatedServerFinalMessage =
226 "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
227 if (!clientCalculatedServerFinalMessage.equals(
228 new String(Base64.decode(challenge, Base64.DEFAULT)))) {
229 throw new Exception();
230 }
231 state = State.VALID_SERVER_RESPONSE;
232 return "";
233 } catch (Exception e) {
234 throw new AuthenticationException(
235 "Server final message does not match calculated final message");
236 }
237 default:
238 throw new InvalidStateException(state);
239 }
240 }
241
242 private static class CacheKey {
243 final String algorithm;
244 final String password;
245 final String salt;
246 final int iterations;
247
248 private CacheKey(String algorithm, String password, String salt, int iterations) {
249 this.algorithm = algorithm;
250 this.password = password;
251 this.salt = salt;
252 this.iterations = iterations;
253 }
254
255 @Override
256 public boolean equals(Object o) {
257 if (this == o) return true;
258 if (o == null || getClass() != o.getClass()) return false;
259 CacheKey cacheKey = (CacheKey) o;
260 return iterations == cacheKey.iterations
261 && Objects.equal(algorithm, cacheKey.algorithm)
262 && Objects.equal(password, cacheKey.password)
263 && Objects.equal(salt, cacheKey.salt);
264 }
265
266 @Override
267 public int hashCode() {
268 return Objects.hashCode(algorithm, password, salt, iterations);
269 }
270 }
271
272 private static class KeyPair {
273 final byte[] clientKey;
274 final byte[] serverKey;
275
276 KeyPair(final byte[] clientKey, final byte[] serverKey) {
277 this.clientKey = clientKey;
278 this.serverKey = serverKey;
279 }
280 }
281}