1package eu.siacs.conversations.crypto.sasl;
2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
23
24public abstract class ScramMechanism extends SaslMechanism {
25
26 public static final SecretKey EMPTY_KEY =
27 new SecretKey() {
28 @Override
29 public String getAlgorithm() {
30 return "HMAC";
31 }
32
33 @Override
34 public String getFormat() {
35 return "RAW";
36 }
37
38 @Override
39 public byte[] getEncoded() {
40 return new byte[0];
41 }
42 };
43
44 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
45 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
46 private static final Cache<CacheKey, KeyPair> CACHE =
47 CacheBuilder.newBuilder().maximumSize(10).build();
48 protected final ChannelBinding channelBinding;
49 private final String gs2Header;
50 private final String clientNonce;
51 private final String clientFirstMessageBare;
52 private byte[] serverSignature = null;
53 private DowngradeProtection downgradeProtection = null;
54
55 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
56 super(account);
57 this.channelBinding = channelBinding;
58 if (channelBinding == ChannelBinding.NONE) {
59 // TODO this needs to be changed to "y,," for the scram internal down grade protection
60 // but we might risk compatibility issues if the server supports a binding that we don’t
61 // support
62 this.gs2Header = "n,,";
63 } else {
64 this.gs2Header =
65 String.format(
66 "p=%s,,",
67 CaseFormat.UPPER_UNDERSCORE
68 .converterTo(CaseFormat.LOWER_HYPHEN)
69 .convert(channelBinding.toString()));
70 }
71 // This nonce should be different for each authentication attempt.
72 this.clientNonce = CryptoHelper.random(100);
73 this.clientFirstMessageBare =
74 String.format(
75 "n=%s,r=%s",
76 CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
77 this.clientNonce);
78 }
79
80 public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
81 Preconditions.checkState(
82 this.state == State.INITIAL, "setting downgrade protection in invalid state");
83 this.downgradeProtection = downgradeProtection;
84 }
85
86 protected abstract HashFunction getHMac(final byte[] key);
87
88 protected abstract HashFunction getDigest();
89
90 private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
91 throws ExecutionException {
92 final var key = new CacheKey(getMechanism(), password, salt, iterations);
93 return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
94 }
95
96 private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
97 throws InvalidKeyException {
98 final byte[] saltedPassword, serverKey, clientKey;
99 saltedPassword = hi(password.getBytes(), salt, iterations);
100 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
101 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
102 return new KeyPair(clientKey, serverKey);
103 }
104
105 @Override
106 public String getMechanism() {
107 return "";
108 }
109
110 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
111 return getHMac(key).hashBytes(input).asBytes();
112 }
113
114 private byte[] digest(final byte[] bytes) {
115 return getDigest().hashBytes(bytes).asBytes();
116 }
117
118 /*
119 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
120 * pseudorandom function (PRF) and with dkLen == output length of
121 * HMAC() == output length of H().
122 */
123 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
124 throws InvalidKeyException {
125 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
126 byte[] out = u.clone();
127 for (int i = 1; i < iterations; i++) {
128 u = hmac(key, u);
129 for (int j = 0; j < u.length; j++) {
130 out[j] ^= u[j];
131 }
132 }
133 return out;
134 }
135
136 @Override
137 public String getClientFirstMessage(final SSLSocket sslSocket) {
138 Preconditions.checkState(
139 this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
140 this.state = State.AUTH_TEXT_SENT;
141 final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
142 return BaseEncoding.base64().encode(message);
143 }
144
145 @Override
146 public String getResponse(final String challenge, final SSLSocket socket)
147 throws AuthenticationException {
148 return switch (state) {
149 case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
150 case RESPONSE_SENT -> processServerFinalMessage(challenge);
151 default -> throw new InvalidStateException(state);
152 };
153 }
154
155 private String processServerFirstMessage(final String challenge, final SSLSocket socket)
156 throws AuthenticationException {
157 if (Strings.isNullOrEmpty(challenge)) {
158 throw new AuthenticationException("challenge can not be null");
159 }
160 byte[] serverFirstMessage;
161 try {
162 serverFirstMessage = BaseEncoding.base64().decode(challenge);
163 } catch (final IllegalArgumentException e) {
164 throw new AuthenticationException("Unable to decode server challenge", e);
165 }
166 final Map<String, String> attributes;
167 try {
168 attributes = splitToAttributes(new String(serverFirstMessage));
169 } catch (final IllegalArgumentException e) {
170 throw new AuthenticationException("Duplicate attributes");
171 }
172 if (attributes.containsKey("m")) {
173 /*
174 * RFC 5802:
175 * m: This attribute is reserved for future extensibility. In this
176 * version of SCRAM, its presence in a client or a server message
177 * MUST cause authentication failure when the attribute is parsed by
178 * the other end.
179 */
180 throw new AuthenticationException("Server sent reserved token: 'm'");
181 }
182 final String i = attributes.get("i");
183 final String s = attributes.get("s");
184 final String nonce = attributes.get("r");
185 final String h = attributes.get("h");
186 if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
187 throw new AuthenticationException("Missing attributes from server first message");
188 }
189 final Integer iterationCount = Ints.tryParse(i);
190
191 if (iterationCount == null || iterationCount < 0) {
192 throw new AuthenticationException("Server did not send iteration count");
193 }
194 if (!nonce.startsWith(clientNonce)) {
195 throw new AuthenticationException(
196 "Server nonce does not contain client nonce: " + nonce);
197 }
198
199 final byte[] salt;
200
201 try {
202 salt = BaseEncoding.base64().decode(s);
203 } catch (final IllegalArgumentException e) {
204 throw new AuthenticationException("Invalid salt in server first message");
205 }
206
207 if (h != null && this.downgradeProtection != null) {
208 final String asSeenInFeatures;
209 try {
210 asSeenInFeatures = downgradeProtection.asHString();
211 } catch (final SecurityException e) {
212 throw new AuthenticationException(e);
213 }
214 final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
215 if (!hashed.equals(h)) {
216 throw new AuthenticationException("Mismatch in SSDP");
217 }
218 }
219
220 final byte[] channelBindingData = getChannelBindingData(socket);
221
222 final int gs2Len = this.gs2Header.getBytes().length;
223 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
224 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
225 System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
226
227 final String clientFinalMessageWithoutProof =
228 String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
229
230 final var authMessage =
231 Joiner.on(',')
232 .join(
233 clientFirstMessageBare,
234 new String(serverFirstMessage),
235 clientFinalMessageWithoutProof);
236
237 final KeyPair keys;
238 try {
239 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
240 } catch (final ExecutionException e) {
241 throw new AuthenticationException("Invalid keys generated");
242 }
243 final byte[] clientSignature;
244 try {
245 serverSignature = hmac(keys.serverKey, authMessage.getBytes());
246 final byte[] storedKey = digest(keys.clientKey);
247
248 clientSignature = hmac(storedKey, authMessage.getBytes());
249
250 } catch (final InvalidKeyException e) {
251 throw new AuthenticationException(e);
252 }
253
254 final byte[] clientProof = new byte[keys.clientKey.length];
255
256 if (clientSignature.length < keys.clientKey.length) {
257 throw new AuthenticationException("client signature was shorter than clientKey");
258 }
259
260 for (int j = 0; j < clientProof.length; j++) {
261 clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
262 }
263
264 final var clientFinalMessage =
265 String.format(
266 "%s,p=%s",
267 clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
268 this.state = State.RESPONSE_SENT;
269 return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
270 }
271
272 private Map<String, String> splitToAttributes(final String message) {
273 final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
274 for (final String token : Splitter.on(',').split(message)) {
275 final var tuple = Splitter.on('=').limit(2).splitToList(token);
276 if (tuple.size() == 2) {
277 builder.put(tuple.get(0), tuple.get(1));
278 }
279 }
280 return builder.buildOrThrow();
281 }
282
283 private String processServerFinalMessage(final String challenge)
284 throws AuthenticationException {
285 final String serverFinalMessage;
286 try {
287 serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
288 } catch (final IllegalArgumentException e) {
289 throw new AuthenticationException("Invalid base64 in server final message", e);
290 }
291 final var clientCalculatedServerFinalMessage =
292 String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
293 if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
294 this.state = State.VALID_SERVER_RESPONSE;
295 return "";
296 }
297 throw new AuthenticationException(
298 "Server final message does not match calculated final message");
299 }
300
301 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
302 throws AuthenticationException {
303 if (this.channelBinding == ChannelBinding.NONE) {
304 return new byte[0];
305 }
306 throw new AssertionError("getChannelBindingData needs to be overwritten");
307 }
308
309 private static class CacheKey {
310 private final String algorithm;
311 private final String password;
312 private final byte[] salt;
313 private final int iterations;
314
315 private CacheKey(
316 final String algorithm,
317 final String password,
318 final byte[] salt,
319 final int iterations) {
320 this.algorithm = algorithm;
321 this.password = password;
322 this.salt = salt;
323 this.iterations = iterations;
324 }
325
326 @Override
327 public boolean equals(final Object o) {
328 if (this == o) return true;
329 if (o == null || getClass() != o.getClass()) return false;
330 CacheKey cacheKey = (CacheKey) o;
331 return iterations == cacheKey.iterations
332 && Objects.equal(algorithm, cacheKey.algorithm)
333 && Objects.equal(password, cacheKey.password)
334 && Arrays.equals(salt, cacheKey.salt);
335 }
336
337 @Override
338 public int hashCode() {
339 final int result = Objects.hashCode(algorithm, password, iterations);
340 return 31 * result + Arrays.hashCode(salt);
341 }
342 }
343
344 private static class KeyPair {
345 final byte[] clientKey;
346 final byte[] serverKey;
347
348 KeyPair(final byte[] clientKey, final byte[] serverKey) {
349 this.clientKey = clientKey;
350 this.serverKey = serverKey;
351 }
352 }
353}