ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
 22
 23abstract class ScramMechanism extends SaslMechanism {
 24
 25    public static final SecretKey EMPTY_KEY =
 26            new SecretKey() {
 27                @Override
 28                public String getAlgorithm() {
 29                    return "HMAC";
 30                }
 31
 32                @Override
 33                public String getFormat() {
 34                    return "RAW";
 35                }
 36
 37                @Override
 38                public byte[] getEncoded() {
 39                    return new byte[0];
 40                }
 41            };
 42
 43    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 44    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 45    private static final Cache<CacheKey, KeyPair> CACHE =
 46            CacheBuilder.newBuilder().maximumSize(10).build();
 47    protected final ChannelBinding channelBinding;
 48    private final String gs2Header;
 49    private final String clientNonce;
 50    protected State state = State.INITIAL;
 51    private final String clientFirstMessageBare;
 52    private byte[] serverSignature = null;
 53
 54    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 55        super(account);
 56        this.channelBinding = channelBinding;
 57        if (channelBinding == ChannelBinding.NONE) {
 58            // TODO this needs to be changed to "y,," for the scram internal down grade protection
 59            // but we might risk compatibility issues if the server supports a binding that we don’t
 60            // support
 61            this.gs2Header = "n,,";
 62        } else {
 63            this.gs2Header =
 64                    String.format(
 65                            "p=%s,,",
 66                            CaseFormat.UPPER_UNDERSCORE
 67                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 68                                    .convert(channelBinding.toString()));
 69        }
 70        // This nonce should be different for each authentication attempt.
 71        this.clientNonce = CryptoHelper.random(100);
 72        this.clientFirstMessageBare =
 73                String.format(
 74                        "n=%s,r=%s",
 75                        CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
 76                        this.clientNonce);
 77    }
 78
 79    protected abstract HashFunction getHMac(final byte[] key);
 80
 81    protected abstract HashFunction getDigest();
 82
 83    private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
 84            throws ExecutionException {
 85        final var key = new CacheKey(getMechanism(), password, salt, iterations);
 86        return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
 87    }
 88
 89    private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
 90            throws InvalidKeyException {
 91        final byte[] saltedPassword, serverKey, clientKey;
 92        saltedPassword = hi(password.getBytes(), salt, iterations);
 93        serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 94        clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 95        return new KeyPair(clientKey, serverKey);
 96    }
 97
 98    @Override
 99    public String getMechanism() {
100        return "";
101    }
102
103    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
104        return getHMac(key).hashBytes(input).asBytes();
105    }
106
107    private byte[] digest(final byte[] bytes) {
108        return getDigest().hashBytes(bytes).asBytes();
109    }
110
111    /*
112     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
113     * pseudorandom function (PRF) and with dkLen == output length of
114     * HMAC() == output length of H().
115     */
116    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
117            throws InvalidKeyException {
118        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
119        byte[] out = u.clone();
120        for (int i = 1; i < iterations; i++) {
121            u = hmac(key, u);
122            for (int j = 0; j < u.length; j++) {
123                out[j] ^= u[j];
124            }
125        }
126        return out;
127    }
128
129    @Override
130    public String getClientFirstMessage(final SSLSocket sslSocket) {
131        if (this.state != State.INITIAL) {
132            throw new IllegalArgumentException("Calling getClientFirstMessage from invalid state");
133        }
134        this.state = State.AUTH_TEXT_SENT;
135        final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
136        return BaseEncoding.base64().encode(message);
137    }
138
139    @Override
140    public String getResponse(final String challenge, final SSLSocket socket)
141            throws AuthenticationException {
142        return switch (state) {
143            case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
144            case RESPONSE_SENT -> processServerFinalMessage(challenge);
145            default -> throw new InvalidStateException(state);
146        };
147    }
148
149    private String processServerFirstMessage(final String challenge, final SSLSocket socket)
150            throws AuthenticationException {
151        if (Strings.isNullOrEmpty(challenge)) {
152            throw new AuthenticationException("challenge can not be null");
153        }
154        byte[] serverFirstMessage;
155        try {
156            serverFirstMessage = BaseEncoding.base64().decode(challenge);
157        } catch (final IllegalArgumentException e) {
158            throw new AuthenticationException("Unable to decode server challenge", e);
159        }
160        final Map<String, String> attributes;
161        try {
162            attributes = splitToAttributes(new String(serverFirstMessage));
163        } catch (final IllegalArgumentException e) {
164            throw new AuthenticationException("Duplicate attributes");
165        }
166        if (attributes.containsKey("m")) {
167            /*
168             * RFC 5802:
169             * m: This attribute is reserved for future extensibility.  In this
170             * version of SCRAM, its presence in a client or a server message
171             * MUST cause authentication failure when the attribute is parsed by
172             * the other end.
173             */
174            throw new AuthenticationException("Server sent reserved token: 'm'");
175        }
176        final String i = attributes.get("i");
177        final String s = attributes.get("s");
178        final String nonce = attributes.get("r");
179        final String d = attributes.get("d");
180        if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
181            throw new AuthenticationException("Missing attributes from server first message");
182        }
183        final Integer iterationCount = Ints.tryParse(i);
184
185        if (iterationCount == null || iterationCount < 0) {
186            throw new AuthenticationException("Server did not send iteration count");
187        }
188        if (!nonce.startsWith(clientNonce)) {
189            throw new AuthenticationException(
190                    "Server nonce does not contain client nonce: " + nonce);
191        }
192
193        final byte[] salt;
194
195        try {
196            salt = BaseEncoding.base64().decode(s);
197        } catch (final IllegalArgumentException e) {
198            throw new AuthenticationException("Invalid salt in server first message");
199        }
200
201        final byte[] channelBindingData = getChannelBindingData(socket);
202
203        final int gs2Len = this.gs2Header.getBytes().length;
204        final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
205        System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
206        System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
207
208        final String clientFinalMessageWithoutProof =
209                String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
210
211        final var authMessage =
212                Joiner.on(',')
213                        .join(
214                                clientFirstMessageBare,
215                                new String(serverFirstMessage),
216                                clientFinalMessageWithoutProof);
217
218        final KeyPair keys;
219        try {
220            keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
221        } catch (final ExecutionException e) {
222            throw new AuthenticationException("Invalid keys generated");
223        }
224        final byte[] clientSignature;
225        try {
226            serverSignature = hmac(keys.serverKey, authMessage.getBytes());
227            final byte[] storedKey = digest(keys.clientKey);
228
229            clientSignature = hmac(storedKey, authMessage.getBytes());
230
231        } catch (final InvalidKeyException e) {
232            throw new AuthenticationException(e);
233        }
234
235        final byte[] clientProof = new byte[keys.clientKey.length];
236
237        if (clientSignature.length < keys.clientKey.length) {
238            throw new AuthenticationException("client signature was shorter than clientKey");
239        }
240
241        for (int j = 0; j < clientProof.length; j++) {
242            clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
243        }
244
245        final var clientFinalMessage =
246                String.format(
247                        "%s,p=%s",
248                        clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
249        this.state = State.RESPONSE_SENT;
250        return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
251    }
252
253    private Map<String, String> splitToAttributes(final String message) {
254        final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
255        for (final String token : Splitter.on(',').split(message)) {
256            final var tuple = Splitter.on('=').limit(2).splitToList(token);
257            if (tuple.size() == 2) {
258                builder.put(tuple.get(0), tuple.get(1));
259            }
260        }
261        return builder.buildOrThrow();
262    }
263
264    private String processServerFinalMessage(final String challenge)
265            throws AuthenticationException {
266        final String serverFinalMessage;
267        try {
268            serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
269        } catch (final IllegalArgumentException e) {
270            throw new AuthenticationException("Invalid base64 in server final message", e);
271        }
272        final var clientCalculatedServerFinalMessage =
273                String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
274        if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
275            this.state = State.VALID_SERVER_RESPONSE;
276            return "";
277        }
278        throw new AuthenticationException(
279                "Server final message does not match calculated final message");
280    }
281
282    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
283            throws AuthenticationException {
284        if (this.channelBinding == ChannelBinding.NONE) {
285            return new byte[0];
286        }
287        throw new AssertionError("getChannelBindingData needs to be overwritten");
288    }
289
290    private static class CacheKey {
291        private final String algorithm;
292        private final String password;
293        private final byte[] salt;
294        private final int iterations;
295
296        private CacheKey(
297                final String algorithm,
298                final String password,
299                final byte[] salt,
300                final int iterations) {
301            this.algorithm = algorithm;
302            this.password = password;
303            this.salt = salt;
304            this.iterations = iterations;
305        }
306
307        @Override
308        public boolean equals(final Object o) {
309            if (this == o) return true;
310            if (o == null || getClass() != o.getClass()) return false;
311            CacheKey cacheKey = (CacheKey) o;
312            return iterations == cacheKey.iterations
313                    && Objects.equal(algorithm, cacheKey.algorithm)
314                    && Objects.equal(password, cacheKey.password)
315                    && Arrays.equals(salt, cacheKey.salt);
316        }
317
318        @Override
319        public int hashCode() {
320            final int result = Objects.hashCode(algorithm, password, iterations);
321            return 31 * result + Arrays.hashCode(salt);
322        }
323    }
324
325    private static class KeyPair {
326        final byte[] clientKey;
327        final byte[] serverKey;
328
329        KeyPair(final byte[] clientKey, final byte[] serverKey) {
330            this.clientKey = clientKey;
331            this.serverKey = serverKey;
332        }
333    }
334}