ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
 23
 24public abstract class ScramMechanism extends SaslMechanism {
 25
 26    public static final SecretKey EMPTY_KEY =
 27            new SecretKey() {
 28                @Override
 29                public String getAlgorithm() {
 30                    return "HMAC";
 31                }
 32
 33                @Override
 34                public String getFormat() {
 35                    return "RAW";
 36                }
 37
 38                @Override
 39                public byte[] getEncoded() {
 40                    return new byte[0];
 41                }
 42            };
 43
 44    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 45    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 46    private static final Cache<CacheKey, KeyPair> CACHE =
 47            CacheBuilder.newBuilder().maximumSize(10).build();
 48    protected final ChannelBinding channelBinding;
 49    private final String gs2Header;
 50    private final String clientNonce;
 51    private final String clientFirstMessageBare;
 52    private byte[] serverSignature = null;
 53    private DowngradeProtection downgradeProtection = null;
 54
 55    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 56        super(account);
 57        this.channelBinding = channelBinding;
 58        if (channelBinding == ChannelBinding.NONE) {
 59            // TODO this needs to be changed to "y,," for the scram internal down grade protection
 60            // but we might risk compatibility issues if the server supports a binding that we don’t
 61            // support
 62            this.gs2Header = "n,,";
 63        } else {
 64            this.gs2Header =
 65                    String.format(
 66                            "p=%s,,",
 67                            CaseFormat.UPPER_UNDERSCORE
 68                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 69                                    .convert(channelBinding.toString()));
 70        }
 71        // This nonce should be different for each authentication attempt.
 72        this.clientNonce = CryptoHelper.random(100);
 73        this.clientFirstMessageBare =
 74                String.format(
 75                        "n=%s,r=%s",
 76                        CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
 77                        this.clientNonce);
 78    }
 79
 80    public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
 81        Preconditions.checkState(
 82                this.state == State.INITIAL, "setting downgrade protection in invalid state");
 83        this.downgradeProtection = downgradeProtection;
 84    }
 85
 86    protected abstract HashFunction getHMac(final byte[] key);
 87
 88    protected abstract HashFunction getDigest();
 89
 90    private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
 91            throws ExecutionException {
 92        final var key = new CacheKey(getMechanism(), password, salt, iterations);
 93        return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
 94    }
 95
 96    private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
 97            throws InvalidKeyException {
 98        final byte[] saltedPassword, serverKey, clientKey;
 99        saltedPassword = hi(password.getBytes(), salt, iterations);
100        serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
101        clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
102        return new KeyPair(clientKey, serverKey);
103    }
104
105    @Override
106    public String getMechanism() {
107        return "";
108    }
109
110    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
111        return getHMac(key).hashBytes(input).asBytes();
112    }
113
114    private byte[] digest(final byte[] bytes) {
115        return getDigest().hashBytes(bytes).asBytes();
116    }
117
118    /*
119     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
120     * pseudorandom function (PRF) and with dkLen == output length of
121     * HMAC() == output length of H().
122     */
123    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
124            throws InvalidKeyException {
125        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
126        byte[] out = u.clone();
127        for (int i = 1; i < iterations; i++) {
128            u = hmac(key, u);
129            for (int j = 0; j < u.length; j++) {
130                out[j] ^= u[j];
131            }
132        }
133        return out;
134    }
135
136    @Override
137    public String getClientFirstMessage(final SSLSocket sslSocket) {
138        Preconditions.checkState(
139                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
140        this.state = State.AUTH_TEXT_SENT;
141        final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
142        return BaseEncoding.base64().encode(message);
143    }
144
145    @Override
146    public String getResponse(final String challenge, final SSLSocket socket)
147            throws AuthenticationException {
148        return switch (state) {
149            case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
150            case RESPONSE_SENT -> processServerFinalMessage(challenge);
151            default -> throw new InvalidStateException(state);
152        };
153    }
154
155    private String processServerFirstMessage(final String challenge, final SSLSocket socket)
156            throws AuthenticationException {
157        if (Strings.isNullOrEmpty(challenge)) {
158            throw new AuthenticationException("challenge can not be null");
159        }
160        byte[] serverFirstMessage;
161        try {
162            serverFirstMessage = BaseEncoding.base64().decode(challenge);
163        } catch (final IllegalArgumentException e) {
164            throw new AuthenticationException("Unable to decode server challenge", e);
165        }
166        final Map<String, String> attributes;
167        try {
168            attributes = splitToAttributes(new String(serverFirstMessage));
169        } catch (final IllegalArgumentException e) {
170            throw new AuthenticationException("Duplicate attributes");
171        }
172        if (attributes.containsKey("m")) {
173            /*
174             * RFC 5802:
175             * m: This attribute is reserved for future extensibility.  In this
176             * version of SCRAM, its presence in a client or a server message
177             * MUST cause authentication failure when the attribute is parsed by
178             * the other end.
179             */
180            throw new AuthenticationException("Server sent reserved token: 'm'");
181        }
182        final String i = attributes.get("i");
183        final String s = attributes.get("s");
184        final String nonce = attributes.get("r");
185        final String h = attributes.get("h");
186        if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
187            throw new AuthenticationException("Missing attributes from server first message");
188        }
189        final Integer iterationCount = Ints.tryParse(i);
190
191        if (iterationCount == null || iterationCount < 0) {
192            throw new AuthenticationException("Server did not send iteration count");
193        }
194        if (!nonce.startsWith(clientNonce)) {
195            throw new AuthenticationException(
196                    "Server nonce does not contain client nonce: " + nonce);
197        }
198
199        final byte[] salt;
200
201        try {
202            salt = BaseEncoding.base64().decode(s);
203        } catch (final IllegalArgumentException e) {
204            throw new AuthenticationException("Invalid salt in server first message");
205        }
206
207        if (h != null && this.downgradeProtection != null) {
208            final String asSeenInFeatures;
209            try {
210                asSeenInFeatures = downgradeProtection.asHString();
211            } catch (final SecurityException e) {
212                throw new AuthenticationException(e);
213            }
214            final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
215            if (!hashed.equals(h)) {
216                throw new AuthenticationException("Mismatch in SSDP");
217            }
218        }
219
220        final byte[] channelBindingData = getChannelBindingData(socket);
221
222        final int gs2Len = this.gs2Header.getBytes().length;
223        final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
224        System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
225        System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
226
227        final String clientFinalMessageWithoutProof =
228                String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
229
230        final var authMessage =
231                Joiner.on(',')
232                        .join(
233                                clientFirstMessageBare,
234                                new String(serverFirstMessage),
235                                clientFinalMessageWithoutProof);
236
237        final KeyPair keys;
238        try {
239            keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
240        } catch (final ExecutionException e) {
241            throw new AuthenticationException("Invalid keys generated");
242        }
243        final byte[] clientSignature;
244        try {
245            serverSignature = hmac(keys.serverKey, authMessage.getBytes());
246            final byte[] storedKey = digest(keys.clientKey);
247
248            clientSignature = hmac(storedKey, authMessage.getBytes());
249
250        } catch (final InvalidKeyException e) {
251            throw new AuthenticationException(e);
252        }
253
254        final byte[] clientProof = new byte[keys.clientKey.length];
255
256        if (clientSignature.length < keys.clientKey.length) {
257            throw new AuthenticationException("client signature was shorter than clientKey");
258        }
259
260        for (int j = 0; j < clientProof.length; j++) {
261            clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
262        }
263
264        final var clientFinalMessage =
265                String.format(
266                        "%s,p=%s",
267                        clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
268        this.state = State.RESPONSE_SENT;
269        return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
270    }
271
272    private Map<String, String> splitToAttributes(final String message) {
273        final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
274        for (final String token : Splitter.on(',').split(message)) {
275            final var tuple = Splitter.on('=').limit(2).splitToList(token);
276            if (tuple.size() == 2) {
277                builder.put(tuple.get(0), tuple.get(1));
278            }
279        }
280        return builder.buildOrThrow();
281    }
282
283    private String processServerFinalMessage(final String challenge)
284            throws AuthenticationException {
285        final String serverFinalMessage;
286        try {
287            serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
288        } catch (final IllegalArgumentException e) {
289            throw new AuthenticationException("Invalid base64 in server final message", e);
290        }
291        final var clientCalculatedServerFinalMessage =
292                String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
293        if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
294            this.state = State.VALID_SERVER_RESPONSE;
295            return "";
296        }
297        throw new AuthenticationException(
298                "Server final message does not match calculated final message");
299    }
300
301    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
302            throws AuthenticationException {
303        if (this.channelBinding == ChannelBinding.NONE) {
304            return new byte[0];
305        }
306        throw new AssertionError("getChannelBindingData needs to be overwritten");
307    }
308
309    private static class CacheKey {
310        private final String algorithm;
311        private final String password;
312        private final byte[] salt;
313        private final int iterations;
314
315        private CacheKey(
316                final String algorithm,
317                final String password,
318                final byte[] salt,
319                final int iterations) {
320            this.algorithm = algorithm;
321            this.password = password;
322            this.salt = salt;
323            this.iterations = iterations;
324        }
325
326        @Override
327        public boolean equals(final Object o) {
328            if (this == o) return true;
329            if (o == null || getClass() != o.getClass()) return false;
330            CacheKey cacheKey = (CacheKey) o;
331            return iterations == cacheKey.iterations
332                    && Objects.equal(algorithm, cacheKey.algorithm)
333                    && Objects.equal(password, cacheKey.password)
334                    && Arrays.equals(salt, cacheKey.salt);
335        }
336
337        @Override
338        public int hashCode() {
339            final int result = Objects.hashCode(algorithm, password, iterations);
340            return 31 * result + Arrays.hashCode(salt);
341        }
342    }
343
344    private static class KeyPair {
345        final byte[] clientKey;
346        final byte[] serverKey;
347
348        KeyPair(final byte[] clientKey, final byte[] serverKey) {
349            this.clientKey = clientKey;
350            this.serverKey = serverKey;
351        }
352    }
353}