1package eu.siacs.conversations.crypto.sasl;
2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
23
24public abstract class ScramMechanism extends SaslMechanism {
25
26 public static final SecretKey EMPTY_KEY =
27 new SecretKey() {
28 @Override
29 public String getAlgorithm() {
30 return "HMAC";
31 }
32
33 @Override
34 public String getFormat() {
35 return "RAW";
36 }
37
38 @Override
39 public byte[] getEncoded() {
40 return new byte[0];
41 }
42 };
43
44 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
45 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
46 private static final Cache<CacheKey, KeyPair> CACHE =
47 CacheBuilder.newBuilder().maximumSize(10).build();
48 protected final ChannelBinding channelBinding;
49 private final String gs2Header;
50 private final String clientNonce;
51 private final String clientFirstMessageBare;
52 private byte[] serverSignature = null;
53 private DowngradeProtection downgradeProtection = null;
54
55 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
56 super(account);
57 this.channelBinding = channelBinding;
58 if (channelBinding == ChannelBinding.NONE) {
59 this.gs2Header = "y,,";
60 } else {
61 this.gs2Header =
62 String.format(
63 "p=%s,,",
64 CaseFormat.UPPER_UNDERSCORE
65 .converterTo(CaseFormat.LOWER_HYPHEN)
66 .convert(channelBinding.toString()));
67 }
68 // This nonce should be different for each authentication attempt.
69 this.clientNonce = CryptoHelper.random(100);
70 this.clientFirstMessageBare =
71 String.format(
72 "n=%s,r=%s",
73 CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
74 this.clientNonce);
75 }
76
77 public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
78 Preconditions.checkState(
79 this.state == State.INITIAL, "setting downgrade protection in invalid state");
80 this.downgradeProtection = downgradeProtection;
81 }
82
    /**
     * Supplies the keyed hash function this mechanism uses wherever an HMAC is computed.
     *
     * @param key key material for the returned function
     * @return hash function whose {@code hashBytes} output is used as the HMAC value
     */
    protected abstract HashFunction getHMac(final byte[] key);
84
    /**
     * Supplies the unkeyed hash function of this mechanism. It is used to derive the stored key
     * from the client key and to hash the downgrade protection string.
     *
     * @return hash function whose {@code hashBytes} output is used as the digest value
     */
    protected abstract HashFunction getDigest();
86
87 private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
88 throws ExecutionException {
89 final var key = new CacheKey(getMechanism(), password, salt, iterations);
90 return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
91 }
92
93 private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
94 throws InvalidKeyException {
95 final byte[] saltedPassword, serverKey, clientKey;
96 saltedPassword = hi(password.getBytes(), salt, iterations);
97 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
98 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
99 return new KeyPair(clientKey, serverKey);
100 }
101
102 @Override
103 public String getMechanism() {
104 return "";
105 }
106
107 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
108 return getHMac(key).hashBytes(input).asBytes();
109 }
110
111 private byte[] digest(final byte[] bytes) {
112 return getDigest().hashBytes(bytes).asBytes();
113 }
114
115 /*
116 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
117 * pseudorandom function (PRF) and with dkLen == output length of
118 * HMAC() == output length of H().
119 */
120 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
121 throws InvalidKeyException {
122 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
123 byte[] out = u.clone();
124 for (int i = 1; i < iterations; i++) {
125 u = hmac(key, u);
126 for (int j = 0; j < u.length; j++) {
127 out[j] ^= u[j];
128 }
129 }
130 return out;
131 }
132
133 @Override
134 public String getClientFirstMessage(final SSLSocket sslSocket) {
135 Preconditions.checkState(
136 this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
137 this.state = State.AUTH_TEXT_SENT;
138 final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
139 return BaseEncoding.base64().encode(message);
140 }
141
142 @Override
143 public String getResponse(final String challenge, final SSLSocket socket)
144 throws AuthenticationException {
145 return switch (state) {
146 case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
147 case RESPONSE_SENT -> processServerFinalMessage(challenge);
148 default -> throw new InvalidStateException(state);
149 };
150 }
151
152 private String processServerFirstMessage(final String challenge, final SSLSocket socket)
153 throws AuthenticationException {
154 if (Strings.isNullOrEmpty(challenge)) {
155 throw new AuthenticationException("challenge can not be null");
156 }
157 byte[] serverFirstMessage;
158 try {
159 serverFirstMessage = BaseEncoding.base64().decode(challenge);
160 } catch (final IllegalArgumentException e) {
161 throw new AuthenticationException("Unable to decode server challenge", e);
162 }
163 final Map<String, String> attributes;
164 try {
165 attributes = splitToAttributes(new String(serverFirstMessage));
166 } catch (final IllegalArgumentException e) {
167 throw new AuthenticationException("Duplicate attributes");
168 }
169 if (attributes.containsKey("m")) {
170 /*
171 * RFC 5802:
172 * m: This attribute is reserved for future extensibility. In this
173 * version of SCRAM, its presence in a client or a server message
174 * MUST cause authentication failure when the attribute is parsed by
175 * the other end.
176 */
177 throw new AuthenticationException("Server sent reserved token: 'm'");
178 }
179 final String i = attributes.get("i");
180 final String s = attributes.get("s");
181 final String nonce = attributes.get("r");
182 final String h = attributes.get("h");
183 if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
184 throw new AuthenticationException("Missing attributes from server first message");
185 }
186 final Integer iterationCount = Ints.tryParse(i);
187
188 if (iterationCount == null || iterationCount < 0) {
189 throw new AuthenticationException("Server did not send iteration count");
190 }
191 if (!nonce.startsWith(clientNonce)) {
192 throw new AuthenticationException(
193 "Server nonce does not contain client nonce: " + nonce);
194 }
195
196 final byte[] salt;
197
198 try {
199 salt = BaseEncoding.base64().decode(s);
200 } catch (final IllegalArgumentException e) {
201 throw new AuthenticationException("Invalid salt in server first message");
202 }
203
204 if (h != null && this.downgradeProtection != null) {
205 final String asSeenInFeatures;
206 try {
207 asSeenInFeatures = downgradeProtection.asHString();
208 } catch (final SecurityException e) {
209 throw new AuthenticationException(e);
210 }
211 final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
212 if (!hashed.equals(h)) {
213 throw new AuthenticationException("Mismatch in SSDP");
214 }
215 }
216
217 final byte[] channelBindingData = getChannelBindingData(socket);
218
219 final int gs2Len = this.gs2Header.getBytes().length;
220 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
221 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
222 System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
223
224 final String clientFinalMessageWithoutProof =
225 String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
226
227 final var authMessage =
228 Joiner.on(',')
229 .join(
230 clientFirstMessageBare,
231 new String(serverFirstMessage),
232 clientFinalMessageWithoutProof);
233
234 final KeyPair keys;
235 try {
236 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
237 } catch (final ExecutionException e) {
238 throw new AuthenticationException("Invalid keys generated");
239 }
240 final byte[] clientSignature;
241 try {
242 serverSignature = hmac(keys.serverKey, authMessage.getBytes());
243 final byte[] storedKey = digest(keys.clientKey);
244
245 clientSignature = hmac(storedKey, authMessage.getBytes());
246
247 } catch (final InvalidKeyException e) {
248 throw new AuthenticationException(e);
249 }
250
251 final byte[] clientProof = new byte[keys.clientKey.length];
252
253 if (clientSignature.length < keys.clientKey.length) {
254 throw new AuthenticationException("client signature was shorter than clientKey");
255 }
256
257 for (int j = 0; j < clientProof.length; j++) {
258 clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
259 }
260
261 final var clientFinalMessage =
262 String.format(
263 "%s,p=%s",
264 clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
265 this.state = State.RESPONSE_SENT;
266 return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
267 }
268
269 private Map<String, String> splitToAttributes(final String message) {
270 final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
271 for (final String token : Splitter.on(',').split(message)) {
272 final var tuple = Splitter.on('=').limit(2).splitToList(token);
273 if (tuple.size() == 2) {
274 builder.put(tuple.get(0), tuple.get(1));
275 }
276 }
277 return builder.buildOrThrow();
278 }
279
280 private String processServerFinalMessage(final String challenge)
281 throws AuthenticationException {
282 final String serverFinalMessage;
283 try {
284 serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
285 } catch (final IllegalArgumentException e) {
286 throw new AuthenticationException("Invalid base64 in server final message", e);
287 }
288 final var clientCalculatedServerFinalMessage =
289 String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
290 if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
291 this.state = State.VALID_SERVER_RESPONSE;
292 return "";
293 }
294 throw new AuthenticationException(
295 "Server final message does not match calculated final message");
296 }
297
298 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
299 throws AuthenticationException {
300 if (this.channelBinding == ChannelBinding.NONE) {
301 return new byte[0];
302 }
303 throw new AssertionError("getChannelBindingData needs to be overwritten");
304 }
305
306 private static class CacheKey {
307 private final String algorithm;
308 private final String password;
309 private final byte[] salt;
310 private final int iterations;
311
312 private CacheKey(
313 final String algorithm,
314 final String password,
315 final byte[] salt,
316 final int iterations) {
317 this.algorithm = algorithm;
318 this.password = password;
319 this.salt = salt;
320 this.iterations = iterations;
321 }
322
323 @Override
324 public boolean equals(final Object o) {
325 if (this == o) return true;
326 if (o == null || getClass() != o.getClass()) return false;
327 CacheKey cacheKey = (CacheKey) o;
328 return iterations == cacheKey.iterations
329 && Objects.equal(algorithm, cacheKey.algorithm)
330 && Objects.equal(password, cacheKey.password)
331 && Arrays.equals(salt, cacheKey.salt);
332 }
333
334 @Override
335 public int hashCode() {
336 final int result = Objects.hashCode(algorithm, password, iterations);
337 return 31 * result + Arrays.hashCode(salt);
338 }
339 }
340
341 private static class KeyPair {
342 final byte[] clientKey;
343 final byte[] serverKey;
344
345 KeyPair(final byte[] clientKey, final byte[] serverKey) {
346 this.clientKey = clientKey;
347 this.serverKey = serverKey;
348 }
349 }
350}