1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;

import com.google.common.base.CaseFormat;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.util.concurrent.UncheckedExecutionException;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.concurrent.ExecutionException;

import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
22
23abstract class ScramMechanism extends SaslMechanism {
24
25 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
26 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
27 private static final Cache<CacheKey, KeyPair> CACHE =
28 CacheBuilder.newBuilder().maximumSize(10).build();
29 protected final ChannelBinding channelBinding;
30 private final String gs2Header;
31 private final String clientNonce;
32 protected State state = State.INITIAL;
33 private String clientFirstMessageBare;
34 private byte[] serverSignature = null;
35
36 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
37 super(account);
38 this.channelBinding = channelBinding;
39 if (channelBinding == ChannelBinding.NONE) {
40 this.gs2Header = "n,,";
41 } else {
42 this.gs2Header =
43 String.format(
44 "p=%s,,",
45 CaseFormat.UPPER_UNDERSCORE
46 .converterTo(CaseFormat.LOWER_HYPHEN)
47 .convert(channelBinding.toString()));
48 }
49 // This nonce should be different for each authentication attempt.
50 this.clientNonce = CryptoHelper.random(100);
51 clientFirstMessageBare = "";
52 }
53
54 protected abstract HMac getHMAC();
55
56 protected abstract Digest getDigest();
57
58 private KeyPair getKeyPair(final String password, final String salt, final int iterations)
59 throws ExecutionException {
60 return CACHE.get(
61 new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations),
62 () -> {
63 final byte[] saltedPassword, serverKey, clientKey;
64 saltedPassword =
65 hi(
66 password.getBytes(),
67 Base64.decode(salt, Base64.DEFAULT),
68 iterations);
69 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
70 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
71 return new KeyPair(clientKey, serverKey);
72 });
73 }
74
75 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
76 final HMac hMac = getHMAC();
77 hMac.init(new KeyParameter(key));
78 hMac.update(input, 0, input.length);
79 final byte[] out = new byte[hMac.getMacSize()];
80 hMac.doFinal(out, 0);
81 return out;
82 }
83
84 public byte[] digest(final byte[] bytes) {
85 final Digest digest = getDigest();
86 digest.reset();
87 digest.update(bytes, 0, bytes.length);
88 final byte[] out = new byte[digest.getDigestSize()];
89 digest.doFinal(out, 0);
90 return out;
91 }
92
93 /*
94 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
95 * pseudorandom function (PRF) and with dkLen == output length of
96 * HMAC() == output length of H().
97 */
98 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
99 throws InvalidKeyException {
100 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
101 byte[] out = u.clone();
102 for (int i = 1; i < iterations; i++) {
103 u = hmac(key, u);
104 for (int j = 0; j < u.length; j++) {
105 out[j] ^= u[j];
106 }
107 }
108 return out;
109 }
110
111 @Override
112 public String getClientFirstMessage() {
113 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
114 clientFirstMessageBare =
115 "n="
116 + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
117 + ",r="
118 + this.clientNonce;
119 state = State.AUTH_TEXT_SENT;
120 }
121 return Base64.encodeToString(
122 (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
123 Base64.NO_WRAP);
124 }
125
126 @Override
127 public String getResponse(final String challenge, final SSLSocket socket)
128 throws AuthenticationException {
129 switch (state) {
130 case AUTH_TEXT_SENT:
131 if (challenge == null) {
132 throw new AuthenticationException("challenge can not be null");
133 }
134 byte[] serverFirstMessage;
135 try {
136 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
137 } catch (IllegalArgumentException e) {
138 throw new AuthenticationException("Unable to decode server challenge", e);
139 }
140 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
141 String nonce = "";
142 int iterationCount = -1;
143 String salt = "";
144 for (final String token : tokenizer) {
145 if (token.charAt(1) == '=') {
146 switch (token.charAt(0)) {
147 case 'i':
148 try {
149 iterationCount = Integer.parseInt(token.substring(2));
150 } catch (final NumberFormatException e) {
151 throw new AuthenticationException(e);
152 }
153 break;
154 case 's':
155 salt = token.substring(2);
156 break;
157 case 'r':
158 nonce = token.substring(2);
159 break;
160 case 'm':
161 /*
162 * RFC 5802:
163 * m: This attribute is reserved for future extensibility. In this
164 * version of SCRAM, its presence in a client or a server message
165 * MUST cause authentication failure when the attribute is parsed by
166 * the other end.
167 */
168 throw new AuthenticationException(
169 "Server sent reserved token: `m'");
170 }
171 }
172 }
173
174 if (iterationCount < 0) {
175 throw new AuthenticationException("Server did not send iteration count");
176 }
177 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
178 throw new AuthenticationException(
179 "Server nonce does not contain client nonce: " + nonce);
180 }
181 if (salt.isEmpty()) {
182 throw new AuthenticationException("Server sent empty salt");
183 }
184
185 final byte[] channelBindingData = getChannelBindingData(socket);
186
187 final int gs2Len = this.gs2Header.getBytes().length;
188 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
189 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
190 System.arraycopy(
191 channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
192
193 final String clientFinalMessageWithoutProof =
194 "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
195
196 final byte[] authMessage =
197 (clientFirstMessageBare
198 + ','
199 + new String(serverFirstMessage)
200 + ','
201 + clientFinalMessageWithoutProof)
202 .getBytes();
203
204 final KeyPair keys;
205 try {
206 keys =
207 getKeyPair(
208 CryptoHelper.saslPrep(account.getPassword()),
209 salt,
210 iterationCount);
211 } catch (ExecutionException e) {
212 throw new AuthenticationException("Invalid keys generated");
213 }
214 final byte[] clientSignature;
215 try {
216 serverSignature = hmac(keys.serverKey, authMessage);
217 final byte[] storedKey = digest(keys.clientKey);
218
219 clientSignature = hmac(storedKey, authMessage);
220
221 } catch (final InvalidKeyException e) {
222 throw new AuthenticationException(e);
223 }
224
225 final byte[] clientProof = new byte[keys.clientKey.length];
226
227 if (clientSignature.length < keys.clientKey.length) {
228 throw new AuthenticationException(
229 "client signature was shorter than clientKey");
230 }
231
232 for (int i = 0; i < clientProof.length; i++) {
233 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
234 }
235
236 final String clientFinalMessage =
237 clientFinalMessageWithoutProof
238 + ",p="
239 + Base64.encodeToString(clientProof, Base64.NO_WRAP);
240 state = State.RESPONSE_SENT;
241 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
242 case RESPONSE_SENT:
243 try {
244 final String clientCalculatedServerFinalMessage =
245 "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
246 if (!clientCalculatedServerFinalMessage.equals(
247 new String(Base64.decode(challenge, Base64.DEFAULT)))) {
248 throw new Exception();
249 }
250 state = State.VALID_SERVER_RESPONSE;
251 return "";
252 } catch (Exception e) {
253 throw new AuthenticationException(
254 "Server final message does not match calculated final message");
255 }
256 default:
257 throw new InvalidStateException(state);
258 }
259 }
260
261 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
262 throws AuthenticationException {
263 if (this.channelBinding == ChannelBinding.NONE) {
264 return new byte[0];
265 }
266 throw new AssertionError("getChannelBindingData needs to be overwritten");
267 }
268
269 private static class CacheKey {
270 final String algorithm;
271 final String password;
272 final String salt;
273 final int iterations;
274
275 private CacheKey(String algorithm, String password, String salt, int iterations) {
276 this.algorithm = algorithm;
277 this.password = password;
278 this.salt = salt;
279 this.iterations = iterations;
280 }
281
282 @Override
283 public boolean equals(Object o) {
284 if (this == o) return true;
285 if (o == null || getClass() != o.getClass()) return false;
286 CacheKey cacheKey = (CacheKey) o;
287 return iterations == cacheKey.iterations
288 && Objects.equal(algorithm, cacheKey.algorithm)
289 && Objects.equal(password, cacheKey.password)
290 && Objects.equal(salt, cacheKey.salt);
291 }
292
293 @Override
294 public int hashCode() {
295 return Objects.hashCode(algorithm, password, salt, iterations);
296 }
297 }
298
299 private static class KeyPair {
300 final byte[] clientKey;
301 final byte[] serverKey;
302
303 KeyPair(final byte[] clientKey, final byte[] serverKey) {
304 this.clientKey = clientKey;
305 this.serverKey = serverKey;
306 }
307 }
308}