1package eu.siacs.conversations.crypto.sasl;
2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
22
23abstract class ScramMechanism extends SaslMechanism {
24
25 public static final SecretKey EMPTY_KEY =
26 new SecretKey() {
27 @Override
28 public String getAlgorithm() {
29 return "HMAC";
30 }
31
32 @Override
33 public String getFormat() {
34 return "RAW";
35 }
36
37 @Override
38 public byte[] getEncoded() {
39 return new byte[0];
40 }
41 };
42
43 private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
44 private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
45 private static final Cache<CacheKey, KeyPair> CACHE =
46 CacheBuilder.newBuilder().maximumSize(10).build();
47 protected final ChannelBinding channelBinding;
48 private final String gs2Header;
49 private final String clientNonce;
50 protected State state = State.INITIAL;
51 private final String clientFirstMessageBare;
52 private byte[] serverSignature = null;
53
54 ScramMechanism(final Account account, final ChannelBinding channelBinding) {
55 super(account);
56 this.channelBinding = channelBinding;
57 if (channelBinding == ChannelBinding.NONE) {
58 // TODO this needs to be changed to "y,," for the scram internal down grade protection
59 // but we might risk compatibility issues if the server supports a binding that we don’t
60 // support
61 this.gs2Header = "n,,";
62 } else {
63 this.gs2Header =
64 String.format(
65 "p=%s,,",
66 CaseFormat.UPPER_UNDERSCORE
67 .converterTo(CaseFormat.LOWER_HYPHEN)
68 .convert(channelBinding.toString()));
69 }
70 // This nonce should be different for each authentication attempt.
71 this.clientNonce = CryptoHelper.random(100);
72 this.clientFirstMessageBare =
73 String.format(
74 "n=%s,r=%s",
75 CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
76 this.clientNonce);
77 }
78
79 protected abstract HashFunction getHMac(final byte[] key);
80
81 protected abstract HashFunction getDigest();
82
83 private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
84 throws ExecutionException {
85 final var key = new CacheKey(getMechanism(), password, salt, iterations);
86 return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
87 }
88
89 private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
90 throws InvalidKeyException {
91 final byte[] saltedPassword, serverKey, clientKey;
92 saltedPassword = hi(password.getBytes(), salt, iterations);
93 serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
94 clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
95 return new KeyPair(clientKey, serverKey);
96 }
97
98 @Override
99 public String getMechanism() {
100 return "";
101 }
102
103 private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
104 return getHMac(key).hashBytes(input).asBytes();
105 }
106
107 private byte[] digest(final byte[] bytes) {
108 return getDigest().hashBytes(bytes).asBytes();
109 }
110
111 /*
112 * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
113 * pseudorandom function (PRF) and with dkLen == output length of
114 * HMAC() == output length of H().
115 */
116 private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
117 throws InvalidKeyException {
118 byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
119 byte[] out = u.clone();
120 for (int i = 1; i < iterations; i++) {
121 u = hmac(key, u);
122 for (int j = 0; j < u.length; j++) {
123 out[j] ^= u[j];
124 }
125 }
126 return out;
127 }
128
129 @Override
130 public String getClientFirstMessage(final SSLSocket sslSocket) {
131 if (this.state != State.INITIAL) {
132 throw new IllegalArgumentException("Calling getClientFirstMessage from invalid state");
133 }
134 this.state = State.AUTH_TEXT_SENT;
135 final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
136 return BaseEncoding.base64().encode(message);
137 }
138
139 @Override
140 public String getResponse(final String challenge, final SSLSocket socket)
141 throws AuthenticationException {
142 return switch (state) {
143 case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
144 case RESPONSE_SENT -> processServerFinalMessage(challenge);
145 default -> throw new InvalidStateException(state);
146 };
147 }
148
149 private String processServerFirstMessage(final String challenge, final SSLSocket socket)
150 throws AuthenticationException {
151 if (Strings.isNullOrEmpty(challenge)) {
152 throw new AuthenticationException("challenge can not be null");
153 }
154 byte[] serverFirstMessage;
155 try {
156 serverFirstMessage = BaseEncoding.base64().decode(challenge);
157 } catch (final IllegalArgumentException e) {
158 throw new AuthenticationException("Unable to decode server challenge", e);
159 }
160 final Map<String, String> attributes;
161 try {
162 attributes = splitToAttributes(new String(serverFirstMessage));
163 } catch (final IllegalArgumentException e) {
164 throw new AuthenticationException("Duplicate attributes");
165 }
166 if (attributes.containsKey("m")) {
167 /*
168 * RFC 5802:
169 * m: This attribute is reserved for future extensibility. In this
170 * version of SCRAM, its presence in a client or a server message
171 * MUST cause authentication failure when the attribute is parsed by
172 * the other end.
173 */
174 throw new AuthenticationException("Server sent reserved token: 'm'");
175 }
176 final String i = attributes.get("i");
177 final String s = attributes.get("s");
178 final String nonce = attributes.get("r");
179 final String d = attributes.get("d");
180 if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
181 throw new AuthenticationException("Missing attributes from server first message");
182 }
183 final Integer iterationCount = Ints.tryParse(i);
184
185 if (iterationCount == null || iterationCount < 0) {
186 throw new AuthenticationException("Server did not send iteration count");
187 }
188 if (!nonce.startsWith(clientNonce)) {
189 throw new AuthenticationException(
190 "Server nonce does not contain client nonce: " + nonce);
191 }
192
193 final byte[] salt;
194
195 try {
196 salt = BaseEncoding.base64().decode(s);
197 } catch (final IllegalArgumentException e) {
198 throw new AuthenticationException("Invalid salt in server first message");
199 }
200
201 final byte[] channelBindingData = getChannelBindingData(socket);
202
203 final int gs2Len = this.gs2Header.getBytes().length;
204 final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
205 System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
206 System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
207
208 final String clientFinalMessageWithoutProof =
209 String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
210
211 final var authMessage =
212 Joiner.on(',')
213 .join(
214 clientFirstMessageBare,
215 new String(serverFirstMessage),
216 clientFinalMessageWithoutProof);
217
218 final KeyPair keys;
219 try {
220 keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
221 } catch (final ExecutionException e) {
222 throw new AuthenticationException("Invalid keys generated");
223 }
224 final byte[] clientSignature;
225 try {
226 serverSignature = hmac(keys.serverKey, authMessage.getBytes());
227 final byte[] storedKey = digest(keys.clientKey);
228
229 clientSignature = hmac(storedKey, authMessage.getBytes());
230
231 } catch (final InvalidKeyException e) {
232 throw new AuthenticationException(e);
233 }
234
235 final byte[] clientProof = new byte[keys.clientKey.length];
236
237 if (clientSignature.length < keys.clientKey.length) {
238 throw new AuthenticationException("client signature was shorter than clientKey");
239 }
240
241 for (int j = 0; j < clientProof.length; j++) {
242 clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
243 }
244
245 final var clientFinalMessage =
246 String.format(
247 "%s,p=%s",
248 clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
249 this.state = State.RESPONSE_SENT;
250 return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
251 }
252
253 private Map<String, String> splitToAttributes(final String message) {
254 final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
255 for (final String token : Splitter.on(',').split(message)) {
256 final var tuple = Splitter.on('=').limit(2).splitToList(token);
257 if (tuple.size() == 2) {
258 builder.put(tuple.get(0), tuple.get(1));
259 }
260 }
261 return builder.buildOrThrow();
262 }
263
264 private String processServerFinalMessage(final String challenge)
265 throws AuthenticationException {
266 final String serverFinalMessage;
267 try {
268 serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
269 } catch (final IllegalArgumentException e) {
270 throw new AuthenticationException("Invalid base64 in server final message", e);
271 }
272 final var clientCalculatedServerFinalMessage =
273 String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
274 if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
275 this.state = State.VALID_SERVER_RESPONSE;
276 return "";
277 }
278 throw new AuthenticationException(
279 "Server final message does not match calculated final message");
280 }
281
282 protected byte[] getChannelBindingData(final SSLSocket sslSocket)
283 throws AuthenticationException {
284 if (this.channelBinding == ChannelBinding.NONE) {
285 return new byte[0];
286 }
287 throw new AssertionError("getChannelBindingData needs to be overwritten");
288 }
289
290 private static class CacheKey {
291 private final String algorithm;
292 private final String password;
293 private final byte[] salt;
294 private final int iterations;
295
296 private CacheKey(
297 final String algorithm,
298 final String password,
299 final byte[] salt,
300 final int iterations) {
301 this.algorithm = algorithm;
302 this.password = password;
303 this.salt = salt;
304 this.iterations = iterations;
305 }
306
307 @Override
308 public boolean equals(final Object o) {
309 if (this == o) return true;
310 if (o == null || getClass() != o.getClass()) return false;
311 CacheKey cacheKey = (CacheKey) o;
312 return iterations == cacheKey.iterations
313 && Objects.equal(algorithm, cacheKey.algorithm)
314 && Objects.equal(password, cacheKey.password)
315 && Arrays.equals(salt, cacheKey.salt);
316 }
317
318 @Override
319 public int hashCode() {
320 final int result = Objects.hashCode(algorithm, password, iterations);
321 return 31 * result + Arrays.hashCode(salt);
322 }
323 }
324
325 private static class KeyPair {
326 final byte[] clientKey;
327 final byte[] serverKey;
328
329 KeyPair(final byte[] clientKey, final byte[] serverKey) {
330 this.clientKey = clientKey;
331 this.serverKey = serverKey;
332 }
333 }
334}