1package eu.siacs.conversations.crypto.sasl;
2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
23
24public abstract class ScramMechanism extends SaslMechanism {
25
26 public static final SecretKey EMPTY_KEY =
27 new SecretKey() {
28 @Override
29 public String getAlgorithm() {
30 return "HMAC";
31 }
32
33 @Override
34 public String getFormat() {
35 return "RAW";
36 }
37
38 @Override
39 public byte[] getEncoded() {
40 return new byte[0];
41 }
42 };
43
44 // For the SCRAM-SHA-1/SCRAM-SHA-1-PLUS SASL mechanism, servers SHOULD announce a hash
45 // iteration-count of at least 4096.
46 // https://datatracker.ietf.org/doc/html/rfc5802#section-5.1
47 private static final int ITERATION_COUNT_MINIMUM = 4096;
48 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
49 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
50 private static final Cache<CacheKey, KeyPair> CACHE =
51 CacheBuilder.newBuilder().maximumSize(10).build();
52 protected final ChannelBinding channelBinding;
53 private final String gs2Header;
54 private final String clientNonce;
55 private final String clientFirstMessageBare;
56 private byte[] serverSignature = null;
57 private DowngradeProtection downgradeProtection = null;
58
59 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
60 super(account);
61 this.channelBinding = channelBinding;
62 if (channelBinding == ChannelBinding.NONE) {
63 this.gs2Header = "y,,";
64 } else {
65 this.gs2Header =
66 String.format(
67 "p=%s,,",
68 CaseFormat.UPPER_UNDERSCORE
69 .converterTo(CaseFormat.LOWER_HYPHEN)
70 .convert(channelBinding.toString()));
71 }
72 // This nonce should be different for each authentication attempt.
73 this.clientNonce = CryptoHelper.random(100);
74 this.clientFirstMessageBare =
75 String.format(
76 "n=%s,r=%s",
77 CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
78 this.clientNonce);
79 }
80
81 public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
82 Preconditions.checkState(
83 this.state == State.INITIAL, "setting downgrade protection in invalid state");
84 this.downgradeProtection = downgradeProtection;
85 }
86
87 protected abstract HashFunction getHMac(final byte[] key);
88
89 protected abstract HashFunction getDigest();
90
91 private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
92 throws ExecutionException {
93 final var key = new CacheKey(getMechanism(), password, salt, iterations);
94 return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
95 }
96
97 private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
98 throws InvalidKeyException {
99 final byte[] saltedPassword, serverKey, clientKey;
100 saltedPassword = hi(password.getBytes(), salt, iterations);
101 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
102 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
103 return new KeyPair(clientKey, serverKey);
104 }
105
106 @Override
107 public String getMechanism() {
108 return "";
109 }
110
111 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
112 return getHMac(key).hashBytes(input).asBytes();
113 }
114
115 private byte[] digest(final byte[] bytes) {
116 return getDigest().hashBytes(bytes).asBytes();
117 }
118
119 /*
120 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
121 * pseudorandom function (PRF) and with dkLen == output length of
122 * HMAC() == output length of H().
123 */
124 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
125 throws InvalidKeyException {
126 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
127 byte[] out = u.clone();
128 for (int i = 1; i < iterations; i++) {
129 u = hmac(key, u);
130 for (int j = 0; j < u.length; j++) {
131 out[j] ^= u[j];
132 }
133 }
134 return out;
135 }
136
137 @Override
138 public String getClientFirstMessage(final SSLSocket sslSocket) {
139 Preconditions.checkState(
140 this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
141 this.state = State.AUTH_TEXT_SENT;
142 final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
143 return BaseEncoding.base64().encode(message);
144 }
145
146 @Override
147 public String getResponse(final String challenge, final SSLSocket socket)
148 throws AuthenticationException {
149 return switch (state) {
150 case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
151 case RESPONSE_SENT -> processServerFinalMessage(challenge);
152 default -> throw new InvalidStateException(state);
153 };
154 }
155
156 private String processServerFirstMessage(final String challenge, final SSLSocket socket)
157 throws AuthenticationException {
158 if (Strings.isNullOrEmpty(challenge)) {
159 throw new AuthenticationException("challenge can not be null");
160 }
161 byte[] serverFirstMessage;
162 try {
163 serverFirstMessage = BaseEncoding.base64().decode(challenge);
164 } catch (final IllegalArgumentException e) {
165 throw new AuthenticationException("Unable to decode server challenge", e);
166 }
167 final Map<String, String> attributes;
168 try {
169 attributes = splitToAttributes(new String(serverFirstMessage));
170 } catch (final IllegalArgumentException e) {
171 throw new AuthenticationException("Duplicate attributes");
172 }
173 if (attributes.containsKey("m")) {
174 /*
175 * RFC 5802:
176 * m: This attribute is reserved for future extensibility. In this
177 * version of SCRAM, its presence in a client or a server message
178 * MUST cause authentication failure when the attribute is parsed by
179 * the other end.
180 */
181 throw new AuthenticationException("Server sent reserved token: 'm'");
182 }
183 final String i = attributes.get("i");
184 final String s = attributes.get("s");
185 final String nonce = attributes.get("r");
186 final String h = attributes.get("h");
187 if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
188 throw new AuthenticationException("Missing attributes from server first message");
189 }
190 final Integer iterationCount = Ints.tryParse(i);
191
192 if (iterationCount == null || iterationCount < 0) {
193 throw new AuthenticationException("Server did not send iteration count");
194 }
195
196 if (iterationCount < ITERATION_COUNT_MINIMUM) {
197 throw new AuthenticationException(
198 String.format(
199 "Weak iteration count. %d instead of %d",
200 iterationCount, ITERATION_COUNT_MINIMUM));
201 }
202
203 if (!nonce.startsWith(clientNonce)) {
204 throw new AuthenticationException(
205 "Server nonce does not contain client nonce: " + nonce);
206 }
207
208 final byte[] salt;
209
210 try {
211 salt = BaseEncoding.base64().decode(s);
212 } catch (final IllegalArgumentException e) {
213 throw new AuthenticationException("Invalid salt in server first message");
214 }
215
216 if (h != null && this.downgradeProtection != null) {
217 final String asSeenInFeatures;
218 try {
219 asSeenInFeatures = downgradeProtection.asHString();
220 } catch (final SecurityException e) {
221 throw new AuthenticationException(e);
222 }
223 final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
224 if (!hashed.equals(h)) {
225 throw new AuthenticationException("Mismatch in SSDP");
226 }
227 }
228
229 final byte[] channelBindingData = getChannelBindingData(socket);
230
231 final int gs2Len = this.gs2Header.getBytes().length;
232 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
233 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
234 System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
235
236 final String clientFinalMessageWithoutProof =
237 String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
238
239 final var authMessage =
240 Joiner.on(',')
241 .join(
242 clientFirstMessageBare,
243 new String(serverFirstMessage),
244 clientFinalMessageWithoutProof);
245
246 final KeyPair keys;
247 try {
248 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
249 } catch (final ExecutionException e) {
250 throw new AuthenticationException("Invalid keys generated");
251 }
252 final byte[] clientSignature;
253 try {
254 serverSignature = hmac(keys.serverKey, authMessage.getBytes());
255 final byte[] storedKey = digest(keys.clientKey);
256
257 clientSignature = hmac(storedKey, authMessage.getBytes());
258
259 } catch (final InvalidKeyException e) {
260 throw new AuthenticationException(e);
261 }
262
263 final byte[] clientProof = new byte[keys.clientKey.length];
264
265 if (clientSignature.length < keys.clientKey.length) {
266 throw new AuthenticationException("client signature was shorter than clientKey");
267 }
268
269 for (int j = 0; j < clientProof.length; j++) {
270 clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
271 }
272
273 final var clientFinalMessage =
274 String.format(
275 "%s,p=%s",
276 clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
277 this.state = State.RESPONSE_SENT;
278 return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
279 }
280
281 private Map<String, String> splitToAttributes(final String message) {
282 final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
283 for (final String token : Splitter.on(',').split(message)) {
284 final var tuple = Splitter.on('=').limit(2).splitToList(token);
285 if (tuple.size() == 2) {
286 builder.put(tuple.get(0), tuple.get(1));
287 }
288 }
289 return builder.buildOrThrow();
290 }
291
292 private String processServerFinalMessage(final String challenge)
293 throws AuthenticationException {
294 final String serverFinalMessage;
295 try {
296 serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
297 } catch (final IllegalArgumentException e) {
298 throw new AuthenticationException("Invalid base64 in server final message", e);
299 }
300 final var clientCalculatedServerFinalMessage =
301 String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
302 if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
303 this.state = State.VALID_SERVER_RESPONSE;
304 return "";
305 }
306 throw new AuthenticationException(
307 "Server final message does not match calculated final message");
308 }
309
310 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
311 throws AuthenticationException {
312 if (this.channelBinding == ChannelBinding.NONE) {
313 return new byte[0];
314 }
315 throw new AssertionError("getChannelBindingData needs to be overwritten");
316 }
317
318 private static class CacheKey {
319 private final String algorithm;
320 private final String password;
321 private final byte[] salt;
322 private final int iterations;
323
324 private CacheKey(
325 final String algorithm,
326 final String password,
327 final byte[] salt,
328 final int iterations) {
329 this.algorithm = algorithm;
330 this.password = password;
331 this.salt = salt;
332 this.iterations = iterations;
333 }
334
335 @Override
336 public boolean equals(final Object o) {
337 if (this == o) return true;
338 if (o == null || getClass() != o.getClass()) return false;
339 CacheKey cacheKey = (CacheKey) o;
340 return iterations == cacheKey.iterations
341 && Objects.equal(algorithm, cacheKey.algorithm)
342 && Objects.equal(password, cacheKey.password)
343 && Arrays.equals(salt, cacheKey.salt);
344 }
345
346 @Override
347 public int hashCode() {
348 final int result = Objects.hashCode(algorithm, password, iterations);
349 return 31 * result + Arrays.hashCode(salt);
350 }
351 }
352
353 private static class KeyPair {
354 final byte[] clientKey;
355 final byte[] serverKey;
356
357 KeyPair(final byte[] clientKey, final byte[] serverKey) {
358 this.clientKey = clientKey;
359 this.serverKey = serverKey;
360 }
361 }
362}