1package eu.siacs.conversations.crypto.sasl;
2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
23
24public abstract class ScramMechanism extends SaslMechanism {
25
26 public static final SecretKey EMPTY_KEY =
27 new SecretKey() {
28 @Override
29 public String getAlgorithm() {
30 return "HMAC";
31 }
32
33 @Override
34 public String getFormat() {
35 return "RAW";
36 }
37
38 @Override
39 public byte[] getEncoded() {
40 return new byte[0];
41 }
42 };
43
44 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
45 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
46 private static final Cache<CacheKey, KeyPair> CACHE =
47 CacheBuilder.newBuilder().maximumSize(10).build();
48 protected final ChannelBinding channelBinding;
49 private final String gs2Header;
50 private final String clientNonce;
51 protected State state = State.INITIAL;
52 private final String clientFirstMessageBare;
53 private byte[] serverSignature = null;
54 private DowngradeProtection downgradeProtection = null;
55
56 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
57 super(account);
58 this.channelBinding = channelBinding;
59 if (channelBinding == ChannelBinding.NONE) {
60 // TODO this needs to be changed to "y,," for the scram internal down grade protection
61 // but we might risk compatibility issues if the server supports a binding that we don’t
62 // support
63 this.gs2Header = "n,,";
64 } else {
65 this.gs2Header =
66 String.format(
67 "p=%s,,",
68 CaseFormat.UPPER_UNDERSCORE
69 .converterTo(CaseFormat.LOWER_HYPHEN)
70 .convert(channelBinding.toString()));
71 }
72 // This nonce should be different for each authentication attempt.
73 this.clientNonce = CryptoHelper.random(100);
74 this.clientFirstMessageBare =
75 String.format(
76 "n=%s,r=%s",
77 CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
78 this.clientNonce);
79 }
80
81 public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
82 Preconditions.checkState(
83 this.state == State.INITIAL, "setting downgrade protection in invalid state");
84 this.downgradeProtection = downgradeProtection;
85 }
86
87 protected abstract HashFunction getHMac(final byte[] key);
88
89 protected abstract HashFunction getDigest();
90
91 private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
92 throws ExecutionException {
93 final var key = new CacheKey(getMechanism(), password, salt, iterations);
94 return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
95 }
96
97 private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
98 throws InvalidKeyException {
99 final byte[] saltedPassword, serverKey, clientKey;
100 saltedPassword = hi(password.getBytes(), salt, iterations);
101 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
102 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
103 return new KeyPair(clientKey, serverKey);
104 }
105
106 @Override
107 public String getMechanism() {
108 return "";
109 }
110
111 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
112 return getHMac(key).hashBytes(input).asBytes();
113 }
114
115 private byte[] digest(final byte[] bytes) {
116 return getDigest().hashBytes(bytes).asBytes();
117 }
118
119 /*
120 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
121 * pseudorandom function (PRF) and with dkLen == output length of
122 * HMAC() == output length of H().
123 */
124 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
125 throws InvalidKeyException {
126 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
127 byte[] out = u.clone();
128 for (int i = 1; i < iterations; i++) {
129 u = hmac(key, u);
130 for (int j = 0; j < u.length; j++) {
131 out[j] ^= u[j];
132 }
133 }
134 return out;
135 }
136
137 @Override
138 public String getClientFirstMessage(final SSLSocket sslSocket) {
139 Preconditions.checkState(
140 this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
141 this.state = State.AUTH_TEXT_SENT;
142 final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
143 return BaseEncoding.base64().encode(message);
144 }
145
146 @Override
147 public String getResponse(final String challenge, final SSLSocket socket)
148 throws AuthenticationException {
149 return switch (state) {
150 case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
151 case RESPONSE_SENT -> processServerFinalMessage(challenge);
152 default -> throw new InvalidStateException(state);
153 };
154 }
155
156 private String processServerFirstMessage(final String challenge, final SSLSocket socket)
157 throws AuthenticationException {
158 if (Strings.isNullOrEmpty(challenge)) {
159 throw new AuthenticationException("challenge can not be null");
160 }
161 byte[] serverFirstMessage;
162 try {
163 serverFirstMessage = BaseEncoding.base64().decode(challenge);
164 } catch (final IllegalArgumentException e) {
165 throw new AuthenticationException("Unable to decode server challenge", e);
166 }
167 final Map<String, String> attributes;
168 try {
169 attributes = splitToAttributes(new String(serverFirstMessage));
170 } catch (final IllegalArgumentException e) {
171 throw new AuthenticationException("Duplicate attributes");
172 }
173 if (attributes.containsKey("m")) {
174 /*
175 * RFC 5802:
176 * m: This attribute is reserved for future extensibility. In this
177 * version of SCRAM, its presence in a client or a server message
178 * MUST cause authentication failure when the attribute is parsed by
179 * the other end.
180 */
181 throw new AuthenticationException("Server sent reserved token: 'm'");
182 }
183 final String i = attributes.get("i");
184 final String s = attributes.get("s");
185 final String nonce = attributes.get("r");
186 final String d = attributes.get("d");
187 if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
188 throw new AuthenticationException("Missing attributes from server first message");
189 }
190 final Integer iterationCount = Ints.tryParse(i);
191
192 if (iterationCount == null || iterationCount < 0) {
193 throw new AuthenticationException("Server did not send iteration count");
194 }
195 if (!nonce.startsWith(clientNonce)) {
196 throw new AuthenticationException(
197 "Server nonce does not contain client nonce: " + nonce);
198 }
199
200 final byte[] salt;
201
202 try {
203 salt = BaseEncoding.base64().decode(s);
204 } catch (final IllegalArgumentException e) {
205 throw new AuthenticationException("Invalid salt in server first message");
206 }
207
208 if (d != null && this.downgradeProtection != null) {
209 final String asSeenInFeatures;
210 try {
211 asSeenInFeatures = downgradeProtection.asDString();
212 } catch (final SecurityException e) {
213 throw new AuthenticationException(e);
214 }
215 final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
216 if (!hashed.equals(d)) {
217 throw new AuthenticationException("Mismatch in SSDP");
218 }
219 }
220
221 final byte[] channelBindingData = getChannelBindingData(socket);
222
223 final int gs2Len = this.gs2Header.getBytes().length;
224 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
225 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
226 System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
227
228 final String clientFinalMessageWithoutProof =
229 String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
230
231 final var authMessage =
232 Joiner.on(',')
233 .join(
234 clientFirstMessageBare,
235 new String(serverFirstMessage),
236 clientFinalMessageWithoutProof);
237
238 final KeyPair keys;
239 try {
240 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
241 } catch (final ExecutionException e) {
242 throw new AuthenticationException("Invalid keys generated");
243 }
244 final byte[] clientSignature;
245 try {
246 serverSignature = hmac(keys.serverKey, authMessage.getBytes());
247 final byte[] storedKey = digest(keys.clientKey);
248
249 clientSignature = hmac(storedKey, authMessage.getBytes());
250
251 } catch (final InvalidKeyException e) {
252 throw new AuthenticationException(e);
253 }
254
255 final byte[] clientProof = new byte[keys.clientKey.length];
256
257 if (clientSignature.length < keys.clientKey.length) {
258 throw new AuthenticationException("client signature was shorter than clientKey");
259 }
260
261 for (int j = 0; j < clientProof.length; j++) {
262 clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
263 }
264
265 final var clientFinalMessage =
266 String.format(
267 "%s,p=%s",
268 clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
269 this.state = State.RESPONSE_SENT;
270 return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
271 }
272
273 private Map<String, String> splitToAttributes(final String message) {
274 final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
275 for (final String token : Splitter.on(',').split(message)) {
276 final var tuple = Splitter.on('=').limit(2).splitToList(token);
277 if (tuple.size() == 2) {
278 builder.put(tuple.get(0), tuple.get(1));
279 }
280 }
281 return builder.buildOrThrow();
282 }
283
284 private String processServerFinalMessage(final String challenge)
285 throws AuthenticationException {
286 final String serverFinalMessage;
287 try {
288 serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
289 } catch (final IllegalArgumentException e) {
290 throw new AuthenticationException("Invalid base64 in server final message", e);
291 }
292 final var clientCalculatedServerFinalMessage =
293 String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
294 if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
295 this.state = State.VALID_SERVER_RESPONSE;
296 return "";
297 }
298 throw new AuthenticationException(
299 "Server final message does not match calculated final message");
300 }
301
302 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
303 throws AuthenticationException {
304 if (this.channelBinding == ChannelBinding.NONE) {
305 return new byte[0];
306 }
307 throw new AssertionError("getChannelBindingData needs to be overwritten");
308 }
309
310 private static class CacheKey {
311 private final String algorithm;
312 private final String password;
313 private final byte[] salt;
314 private final int iterations;
315
316 private CacheKey(
317 final String algorithm,
318 final String password,
319 final byte[] salt,
320 final int iterations) {
321 this.algorithm = algorithm;
322 this.password = password;
323 this.salt = salt;
324 this.iterations = iterations;
325 }
326
327 @Override
328 public boolean equals(final Object o) {
329 if (this == o) return true;
330 if (o == null || getClass() != o.getClass()) return false;
331 CacheKey cacheKey = (CacheKey) o;
332 return iterations == cacheKey.iterations
333 && Objects.equal(algorithm, cacheKey.algorithm)
334 && Objects.equal(password, cacheKey.password)
335 && Arrays.equals(salt, cacheKey.salt);
336 }
337
338 @Override
339 public int hashCode() {
340 final int result = Objects.hashCode(algorithm, password, iterations);
341 return 31 * result + Arrays.hashCode(salt);
342 }
343 }
344
345 private static class KeyPair {
346 final byte[] clientKey;
347 final byte[] serverKey;
348
349 KeyPair(final byte[] clientKey, final byte[] serverKey) {
350 this.clientKey = clientKey;
351 this.serverKey = serverKey;
352 }
353 }
354}