1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;

import com.google.common.base.CaseFormat;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.concurrent.ExecutionException;

import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
22
23abstract class ScramMechanism extends SaslMechanism {
24
25 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
26 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
27 private static final Cache<CacheKey, KeyPair> CACHE =
28 CacheBuilder.newBuilder().maximumSize(10).build();
29 protected final ChannelBinding channelBinding;
30 private final String gs2Header;
31 private final String clientNonce;
32 protected State state = State.INITIAL;
33 private String clientFirstMessageBare;
34 private byte[] serverSignature = null;
35
36 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
37 super(account);
38 this.channelBinding = channelBinding;
39 if (channelBinding == ChannelBinding.NONE) {
40 // TODO this needs to be changed to "y,," for the scram internal down grade protection
41 // but we might risk compatibility issues if the server supports a binding that we don’t
42 // support
43 this.gs2Header = "n,,";
44 } else {
45 this.gs2Header =
46 String.format(
47 "p=%s,,",
48 CaseFormat.UPPER_UNDERSCORE
49 .converterTo(CaseFormat.LOWER_HYPHEN)
50 .convert(channelBinding.toString()));
51 }
52 // This nonce should be different for each authentication attempt.
53 this.clientNonce = CryptoHelper.random(100);
54 clientFirstMessageBare = "";
55 }
56
57 protected abstract HMac getHMAC();
58
59 protected abstract Digest getDigest();
60
61 private KeyPair getKeyPair(final String password, final String salt, final int iterations)
62 throws ExecutionException {
63 return CACHE.get(
64 new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations),
65 () -> {
66 final byte[] saltedPassword, serverKey, clientKey;
67 saltedPassword =
68 hi(
69 password.getBytes(),
70 Base64.decode(salt, Base64.DEFAULT),
71 iterations);
72 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
73 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
74 return new KeyPair(clientKey, serverKey);
75 });
76 }
77
78 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
79 final HMac hMac = getHMAC();
80 hMac.init(new KeyParameter(key));
81 hMac.update(input, 0, input.length);
82 final byte[] out = new byte[hMac.getMacSize()];
83 hMac.doFinal(out, 0);
84 return out;
85 }
86
87 public byte[] digest(final byte[] bytes) {
88 final Digest digest = getDigest();
89 digest.reset();
90 digest.update(bytes, 0, bytes.length);
91 final byte[] out = new byte[digest.getDigestSize()];
92 digest.doFinal(out, 0);
93 return out;
94 }
95
96 /*
97 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
98 * pseudorandom function (PRF) and with dkLen == output length of
99 * HMAC() == output length of H().
100 */
101 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
102 throws InvalidKeyException {
103 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
104 byte[] out = u.clone();
105 for (int i = 1; i < iterations; i++) {
106 u = hmac(key, u);
107 for (int j = 0; j < u.length; j++) {
108 out[j] ^= u[j];
109 }
110 }
111 return out;
112 }
113
114 @Override
115 public String getClientFirstMessage(final SSLSocket sslSocket) {
116 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
117 clientFirstMessageBare =
118 "n="
119 + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
120 + ",r="
121 + this.clientNonce;
122 state = State.AUTH_TEXT_SENT;
123 }
124 return Base64.encodeToString(
125 (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
126 Base64.NO_WRAP);
127 }
128
129 @Override
130 public String getResponse(final String challenge, final SSLSocket socket)
131 throws AuthenticationException {
132 switch (state) {
133 case AUTH_TEXT_SENT:
134 if (challenge == null) {
135 throw new AuthenticationException("challenge can not be null");
136 }
137 byte[] serverFirstMessage;
138 try {
139 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
140 } catch (IllegalArgumentException e) {
141 throw new AuthenticationException("Unable to decode server challenge", e);
142 }
143 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
144 String nonce = "";
145 int iterationCount = -1;
146 String salt = "";
147 for (final String token : tokenizer) {
148 if (token.charAt(1) == '=') {
149 switch (token.charAt(0)) {
150 case 'i':
151 try {
152 iterationCount = Integer.parseInt(token.substring(2));
153 } catch (final NumberFormatException e) {
154 throw new AuthenticationException(e);
155 }
156 break;
157 case 's':
158 salt = token.substring(2);
159 break;
160 case 'r':
161 nonce = token.substring(2);
162 break;
163 case 'm':
164 /*
165 * RFC 5802:
166 * m: This attribute is reserved for future extensibility. In this
167 * version of SCRAM, its presence in a client or a server message
168 * MUST cause authentication failure when the attribute is parsed by
169 * the other end.
170 */
171 throw new AuthenticationException(
172 "Server sent reserved token: `m'");
173 }
174 }
175 }
176
177 if (iterationCount < 0) {
178 throw new AuthenticationException("Server did not send iteration count");
179 }
180 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
181 throw new AuthenticationException(
182 "Server nonce does not contain client nonce: " + nonce);
183 }
184 if (salt.isEmpty()) {
185 throw new AuthenticationException("Server sent empty salt");
186 }
187
188 final byte[] channelBindingData = getChannelBindingData(socket);
189
190 final int gs2Len = this.gs2Header.getBytes().length;
191 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
192 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
193 System.arraycopy(
194 channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
195
196 final String clientFinalMessageWithoutProof =
197 "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
198
199 final byte[] authMessage =
200 (clientFirstMessageBare
201 + ','
202 + new String(serverFirstMessage)
203 + ','
204 + clientFinalMessageWithoutProof)
205 .getBytes();
206
207 final KeyPair keys;
208 try {
209 keys =
210 getKeyPair(
211 CryptoHelper.saslPrep(account.getPassword()),
212 salt,
213 iterationCount);
214 } catch (ExecutionException e) {
215 throw new AuthenticationException("Invalid keys generated");
216 }
217 final byte[] clientSignature;
218 try {
219 serverSignature = hmac(keys.serverKey, authMessage);
220 final byte[] storedKey = digest(keys.clientKey);
221
222 clientSignature = hmac(storedKey, authMessage);
223
224 } catch (final InvalidKeyException e) {
225 throw new AuthenticationException(e);
226 }
227
228 final byte[] clientProof = new byte[keys.clientKey.length];
229
230 if (clientSignature.length < keys.clientKey.length) {
231 throw new AuthenticationException(
232 "client signature was shorter than clientKey");
233 }
234
235 for (int i = 0; i < clientProof.length; i++) {
236 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
237 }
238
239 final String clientFinalMessage =
240 clientFinalMessageWithoutProof
241 + ",p="
242 + Base64.encodeToString(clientProof, Base64.NO_WRAP);
243 state = State.RESPONSE_SENT;
244 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
245 case RESPONSE_SENT:
246 try {
247 final String clientCalculatedServerFinalMessage =
248 "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
249 if (!clientCalculatedServerFinalMessage.equals(
250 new String(Base64.decode(challenge, Base64.DEFAULT)))) {
251 throw new Exception();
252 }
253 state = State.VALID_SERVER_RESPONSE;
254 return "";
255 } catch (Exception e) {
256 throw new AuthenticationException(
257 "Server final message does not match calculated final message");
258 }
259 default:
260 throw new InvalidStateException(state);
261 }
262 }
263
264 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
265 throws AuthenticationException {
266 if (this.channelBinding == ChannelBinding.NONE) {
267 return new byte[0];
268 }
269 throw new AssertionError("getChannelBindingData needs to be overwritten");
270 }
271
272 private static class CacheKey {
273 final String algorithm;
274 final String password;
275 final String salt;
276 final int iterations;
277
278 private CacheKey(String algorithm, String password, String salt, int iterations) {
279 this.algorithm = algorithm;
280 this.password = password;
281 this.salt = salt;
282 this.iterations = iterations;
283 }
284
285 @Override
286 public boolean equals(Object o) {
287 if (this == o) return true;
288 if (o == null || getClass() != o.getClass()) return false;
289 CacheKey cacheKey = (CacheKey) o;
290 return iterations == cacheKey.iterations
291 && Objects.equal(algorithm, cacheKey.algorithm)
292 && Objects.equal(password, cacheKey.password)
293 && Objects.equal(salt, cacheKey.salt);
294 }
295
296 @Override
297 public int hashCode() {
298 return Objects.hashCode(algorithm, password, salt, iterations);
299 }
300 }
301
302 private static class KeyPair {
303 final byte[] clientKey;
304 final byte[] serverKey;
305
306 KeyPair(final byte[] clientKey, final byte[] serverKey) {
307 this.clientKey = clientKey;
308 this.serverKey = serverKey;
309 }
310 }
311}