1package eu.siacs.conversations.crypto.sasl;
2
import android.util.Base64;
import android.util.Log;

import com.google.common.base.CaseFormat;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashFunction;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.concurrent.ExecutionException;

import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.Config;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
21
22abstract class ScramMechanism extends SaslMechanism {
23
24 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
25 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
26 private static final Cache<CacheKey, KeyPair> CACHE =
27 CacheBuilder.newBuilder().maximumSize(10).build();
28 protected final ChannelBinding channelBinding;
29 private final String gs2Header;
30 private final String clientNonce;
31 protected State state = State.INITIAL;
32 private String clientFirstMessageBare;
33 private byte[] serverSignature = null;
34
35 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
36 super(account);
37 this.channelBinding = channelBinding;
38 if (channelBinding == ChannelBinding.NONE) {
39 // TODO this needs to be changed to "y,," for the scram internal down grade protection
40 // but we might risk compatibility issues if the server supports a binding that we don’t
41 // support
42 this.gs2Header = "n,,";
43 } else {
44 this.gs2Header =
45 String.format(
46 "p=%s,,",
47 CaseFormat.UPPER_UNDERSCORE
48 .converterTo(CaseFormat.LOWER_HYPHEN)
49 .convert(channelBinding.toString()));
50 }
51 // This nonce should be different for each authentication attempt.
52 this.clientNonce = CryptoHelper.random(100);
53 clientFirstMessageBare = "";
54 }
55
56 protected abstract HashFunction getHMac(final byte[] key);
57
58 protected abstract HashFunction getDigest();
59
60 private KeyPair getKeyPair(final String password, final String salt, final int iterations)
61 throws ExecutionException {
62 return CACHE.get(
63 new CacheKey(getMechanism(), password, salt, iterations),
64 () -> {
65 final byte[] saltedPassword, serverKey, clientKey;
66 saltedPassword =
67 hi(
68 password.getBytes(),
69 Base64.decode(salt, Base64.DEFAULT),
70 iterations);
71 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
72 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
73 return new KeyPair(clientKey, serverKey);
74 });
75 }
76
77 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
78 return getHMac(key).hashBytes(input).asBytes();
79 }
80
81 private byte[] digest(final byte[] bytes) {
82 return getDigest().hashBytes(bytes).asBytes();
83 }
84
85 /*
86 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
87 * pseudorandom function (PRF) and with dkLen == output length of
88 * HMAC() == output length of H().
89 */
90 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
91 throws InvalidKeyException {
92 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
93 byte[] out = u.clone();
94 for (int i = 1; i < iterations; i++) {
95 u = hmac(key, u);
96 for (int j = 0; j < u.length; j++) {
97 out[j] ^= u[j];
98 }
99 }
100 return out;
101 }
102
103 @Override
104 public String getClientFirstMessage(final SSLSocket sslSocket) {
105 if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
106 clientFirstMessageBare =
107 "n="
108 + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
109 + ",r="
110 + this.clientNonce;
111 state = State.AUTH_TEXT_SENT;
112 }
113 return Base64.encodeToString(
114 (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
115 Base64.NO_WRAP);
116 }
117
118 @Override
119 public String getResponse(final String challenge, final SSLSocket socket)
120 throws AuthenticationException {
121 switch (state) {
122 case AUTH_TEXT_SENT:
123 if (challenge == null) {
124 throw new AuthenticationException("challenge can not be null");
125 }
126 byte[] serverFirstMessage;
127 try {
128 serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
129 } catch (IllegalArgumentException e) {
130 throw new AuthenticationException("Unable to decode server challenge", e);
131 }
132 final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
133 String nonce = "";
134 int iterationCount = -1;
135 String salt = "";
136 for (final String token : tokenizer) {
137 if (token.charAt(1) == '=') {
138 switch (token.charAt(0)) {
139 case 'i':
140 try {
141 iterationCount = Integer.parseInt(token.substring(2));
142 } catch (final NumberFormatException e) {
143 throw new AuthenticationException(e);
144 }
145 break;
146 case 's':
147 salt = token.substring(2);
148 break;
149 case 'r':
150 nonce = token.substring(2);
151 break;
152 case 'm':
153 /*
154 * RFC 5802:
155 * m: This attribute is reserved for future extensibility. In this
156 * version of SCRAM, its presence in a client or a server message
157 * MUST cause authentication failure when the attribute is parsed by
158 * the other end.
159 */
160 throw new AuthenticationException(
161 "Server sent reserved token: `m'");
162 }
163 }
164 }
165
166 if (iterationCount < 0) {
167 throw new AuthenticationException("Server did not send iteration count");
168 }
169 if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
170 throw new AuthenticationException(
171 "Server nonce does not contain client nonce: " + nonce);
172 }
173 if (salt.isEmpty()) {
174 throw new AuthenticationException("Server sent empty salt");
175 }
176
177 final byte[] channelBindingData = getChannelBindingData(socket);
178
179 final int gs2Len = this.gs2Header.getBytes().length;
180 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
181 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
182 System.arraycopy(
183 channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
184
185 final String clientFinalMessageWithoutProof =
186 "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
187
188 final byte[] authMessage =
189 (clientFirstMessageBare
190 + ','
191 + new String(serverFirstMessage)
192 + ','
193 + clientFinalMessageWithoutProof)
194 .getBytes();
195
196 final KeyPair keys;
197 try {
198 keys =
199 getKeyPair(
200 CryptoHelper.saslPrep(account.getPassword()),
201 salt,
202 iterationCount);
203 } catch (ExecutionException e) {
204 throw new AuthenticationException("Invalid keys generated");
205 }
206 final byte[] clientSignature;
207 try {
208 serverSignature = hmac(keys.serverKey, authMessage);
209 final byte[] storedKey = digest(keys.clientKey);
210
211 clientSignature = hmac(storedKey, authMessage);
212
213 } catch (final InvalidKeyException e) {
214 throw new AuthenticationException(e);
215 }
216
217 final byte[] clientProof = new byte[keys.clientKey.length];
218
219 if (clientSignature.length < keys.clientKey.length) {
220 throw new AuthenticationException(
221 "client signature was shorter than clientKey");
222 }
223
224 for (int i = 0; i < clientProof.length; i++) {
225 clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
226 }
227
228 final String clientFinalMessage =
229 clientFinalMessageWithoutProof
230 + ",p="
231 + Base64.encodeToString(clientProof, Base64.NO_WRAP);
232 state = State.RESPONSE_SENT;
233 return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
234 case RESPONSE_SENT:
235 try {
236 final String clientCalculatedServerFinalMessage =
237 "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
238 if (!clientCalculatedServerFinalMessage.equals(
239 new String(Base64.decode(challenge, Base64.DEFAULT)))) {
240 throw new Exception();
241 }
242 state = State.VALID_SERVER_RESPONSE;
243 return "";
244 } catch (Exception e) {
245 throw new AuthenticationException(
246 "Server final message does not match calculated final message");
247 }
248 default:
249 throw new InvalidStateException(state);
250 }
251 }
252
253 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
254 throws AuthenticationException {
255 if (this.channelBinding == ChannelBinding.NONE) {
256 return new byte[0];
257 }
258 throw new AssertionError("getChannelBindingData needs to be overwritten");
259 }
260
261 private static class CacheKey {
262 final String algorithm;
263 final String password;
264 final String salt;
265 final int iterations;
266
267 private CacheKey(String algorithm, String password, String salt, int iterations) {
268 this.algorithm = algorithm;
269 this.password = password;
270 this.salt = salt;
271 this.iterations = iterations;
272 }
273
274 @Override
275 public boolean equals(Object o) {
276 if (this == o) return true;
277 if (o == null || getClass() != o.getClass()) return false;
278 CacheKey cacheKey = (CacheKey) o;
279 return iterations == cacheKey.iterations
280 && Objects.equal(algorithm, cacheKey.algorithm)
281 && Objects.equal(password, cacheKey.password)
282 && Objects.equal(salt, cacheKey.salt);
283 }
284
285 @Override
286 public int hashCode() {
287 return Objects.hashCode(algorithm, password, salt, iterations);
288 }
289 }
290
291 private static class KeyPair {
292 final byte[] clientKey;
293 final byte[] serverKey;
294
295 KeyPair(final byte[] clientKey, final byte[] serverKey) {
296 this.clientKey = clientKey;
297 this.serverKey = serverKey;
298 }
299 }
300}