1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;

import com.google.common.base.CaseFormat;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashFunction;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.concurrent.ExecutionException;

import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
20
21abstract class ScramMechanism extends SaslMechanism {
22
23 public static final SecretKey EMPTY_KEY =
24 new SecretKey() {
25 @Override
26 public String getAlgorithm() {
27 return "HMAC";
28 }
29
30 @Override
31 public String getFormat() {
32 return "RAW";
33 }
34
35 @Override
36 public byte[] getEncoded() {
37 return new byte[0];
38 }
39 };
40
41 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
42 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
43 private static final Cache<CacheKey, KeyPair> CACHE =
44 CacheBuilder.newBuilder().maximumSize(10).build();
45 protected final ChannelBinding channelBinding;
46 private final String gs2Header;
47 private final String clientNonce;
48 protected State state = State.INITIAL;
49 private String clientFirstMessageBare;
50 private byte[] serverSignature = null;
51
52 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
53 super(account);
54 this.channelBinding = channelBinding;
55 if (channelBinding == ChannelBinding.NONE) {
56 // TODO this needs to be changed to "y,," for the scram internal down grade protection
57 // but we might risk compatibility issues if the server supports a binding that we don’t
58 // support
59 this.gs2Header = "n,,";
60 } else {
61 this.gs2Header =
62 String.format(
63 "p=%s,,",
64 CaseFormat.UPPER_UNDERSCORE
65 .converterTo(CaseFormat.LOWER_HYPHEN)
66 .convert(channelBinding.toString()));
67 }
68 // This nonce should be different for each authentication attempt.
69 this.clientNonce = CryptoHelper.random(100);
70 clientFirstMessageBare = "";
71 }
72
73 protected abstract HashFunction getHMac(final byte[] key);
74
75 protected abstract HashFunction getDigest();
76
77 private KeyPair getKeyPair(final String password, final String salt, final int iterations)
78 throws ExecutionException {
79 return CACHE.get(
80 new CacheKey(getMechanism(), password, salt, iterations),
81 () -> {
82 final byte[] saltedPassword, serverKey, clientKey;
83 saltedPassword =
84 hi(
85 password.getBytes(),
86 Base64.decode(salt, Base64.DEFAULT),
87 iterations);
88 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
89 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
90 return new KeyPair(clientKey, serverKey);
91 });
92 }
93
94 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
95 return getHMac(key).hashBytes(input).asBytes();
96 }
97
98 private byte[] digest(final byte[] bytes) {
99 return getDigest().hashBytes(bytes).asBytes();
100 }
101
102 /*
103 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
104 * pseudorandom function (PRF) and with dkLen == output length of
105 * HMAC() == output length of H().
106 */
107 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
108 throws InvalidKeyException {
109 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
110 byte[] out = u.clone();
111 for (int i = 1; i < iterations; i++) {
112 u = hmac(key, u);
113 for (int j = 0; j < u.length; j++) {
114 out[j] ^= u[j];
115 }
116 }
117 return out;
118 }
119
120 @Override
121 public String getClientFirstMessage(final SSLSocket sslSocket) {
122 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
123 clientFirstMessageBare =
124 "n="
125 + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
126 + ",r="
127 + this.clientNonce;
128 state = State.AUTH_TEXT_SENT;
129 }
130 return Base64.encodeToString(
131 (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
132 Base64.NO_WRAP);
133 }
134
135 @Override
136 public String getResponse(final String challenge, final SSLSocket socket)
137 throws AuthenticationException {
138 switch (state) {
139 case AUTH_TEXT_SENT:
140 if (challenge == null) {
141 throw new AuthenticationException("challenge can not be null");
142 }
143 byte[] serverFirstMessage;
144 try {
145 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
146 } catch (IllegalArgumentException e) {
147 throw new AuthenticationException("Unable to decode server challenge", e);
148 }
149 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
150 String nonce = "";
151 int iterationCount = -1;
152 String salt = "";
153 for (final String token : tokenizer) {
154 if (token.length() > 1 && token.charAt(1) == '=') {
155 switch (token.charAt(0)) {
156 case 'i':
157 try {
158 iterationCount = Integer.parseInt(token.substring(2));
159 } catch (final NumberFormatException e) {
160 throw new AuthenticationException(e);
161 }
162 break;
163 case 's':
164 salt = token.substring(2);
165 break;
166 case 'r':
167 nonce = token.substring(2);
168 break;
169 case 'm':
170 /*
171 * RFC 5802:
172 * m: This attribute is reserved for future extensibility. In this
173 * version of SCRAM, its presence in a client or a server message
174 * MUST cause authentication failure when the attribute is parsed by
175 * the other end.
176 */
177 throw new AuthenticationException(
178 "Server sent reserved token: `m'");
179 }
180 }
181 }
182
183 if (iterationCount < 0) {
184 throw new AuthenticationException("Server did not send iteration count");
185 }
186 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
187 throw new AuthenticationException(
188 "Server nonce does not contain client nonce: " + nonce);
189 }
190 if (salt.isEmpty()) {
191 throw new AuthenticationException("Server sent empty salt");
192 }
193
194 final byte[] channelBindingData = getChannelBindingData(socket);
195
196 final int gs2Len = this.gs2Header.getBytes().length;
197 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
198 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
199 System.arraycopy(
200 channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
201
202 final String clientFinalMessageWithoutProof =
203 "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
204
205 final byte[] authMessage =
206 (clientFirstMessageBare
207 + ','
208 + new String(serverFirstMessage)
209 + ','
210 + clientFinalMessageWithoutProof)
211 .getBytes();
212
213 final KeyPair keys;
214 try {
215 keys =
216 getKeyPair(
217 CryptoHelper.saslPrep(account.getPassword()),
218 salt,
219 iterationCount);
220 } catch (ExecutionException e) {
221 throw new AuthenticationException("Invalid keys generated");
222 }
223 final byte[] clientSignature;
224 try {
225 serverSignature = hmac(keys.serverKey, authMessage);
226 final byte[] storedKey = digest(keys.clientKey);
227
228 clientSignature = hmac(storedKey, authMessage);
229
230 } catch (final InvalidKeyException e) {
231 throw new AuthenticationException(e);
232 }
233
234 final byte[] clientProof = new byte[keys.clientKey.length];
235
236 if (clientSignature.length < keys.clientKey.length) {
237 throw new AuthenticationException(
238 "client signature was shorter than clientKey");
239 }
240
241 for (int i = 0; i < clientProof.length; i++) {
242 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
243 }
244
245 final String clientFinalMessage =
246 clientFinalMessageWithoutProof
247 + ",p="
248 + Base64.encodeToString(clientProof, Base64.NO_WRAP);
249 state = State.RESPONSE_SENT;
250 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
251 case RESPONSE_SENT:
252 try {
253 final String clientCalculatedServerFinalMessage =
254 "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
255 if (!clientCalculatedServerFinalMessage.equals(
256 new String(Base64.decode(challenge, Base64.DEFAULT)))) {
257 throw new Exception();
258 }
259 state = State.VALID_SERVER_RESPONSE;
260 return "";
261 } catch (Exception e) {
262 throw new AuthenticationException(
263 "Server final message does not match calculated final message");
264 }
265 default:
266 throw new InvalidStateException(state);
267 }
268 }
269
270 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
271 throws AuthenticationException {
272 if (this.channelBinding == ChannelBinding.NONE) {
273 return new byte[0];
274 }
275 throw new AssertionError("getChannelBindingData needs to be overwritten");
276 }
277
278 private static class CacheKey {
279 final String algorithm;
280 final String password;
281 final String salt;
282 final int iterations;
283
284 private CacheKey(String algorithm, String password, String salt, int iterations) {
285 this.algorithm = algorithm;
286 this.password = password;
287 this.salt = salt;
288 this.iterations = iterations;
289 }
290
291 @Override
292 public boolean equals(Object o) {
293 if (this == o) return true;
294 if (o == null || getClass() != o.getClass()) return false;
295 CacheKey cacheKey = (CacheKey) o;
296 return iterations == cacheKey.iterations
297 && Objects.equal(algorithm, cacheKey.algorithm)
298 && Objects.equal(password, cacheKey.password)
299 && Objects.equal(salt, cacheKey.salt);
300 }
301
302 @Override
303 public int hashCode() {
304 return Objects.hashCode(algorithm, password, salt, iterations);
305 }
306 }
307
308 private static class KeyPair {
309 final byte[] clientKey;
310 final byte[] serverKey;
311
312 KeyPair(final byte[] clientKey, final byte[] serverKey) {
313 this.clientKey = clientKey;
314 this.serverKey = serverKey;
315 }
316 }
317}