ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
 23
 24public abstract class ScramMechanism extends SaslMechanism {
 25
 26    public static final SecretKey EMPTY_KEY =
 27            new SecretKey() {
 28                @Override
 29                public String getAlgorithm() {
 30                    return "HMAC";
 31                }
 32
 33                @Override
 34                public String getFormat() {
 35                    return "RAW";
 36                }
 37
 38                @Override
 39                public byte[] getEncoded() {
 40                    return new byte[0];
 41                }
 42            };
 43
 44    // For the SCRAM-SHA-1/SCRAM-SHA-1-PLUS SASL mechanism, servers SHOULD announce a hash
 45    // iteration-count of at least 4096.
 46    // https://datatracker.ietf.org/doc/html/rfc5802#section-5.1
 47    private static final int ITERATION_COUNT_MINIMUM = 4096;
 48    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 49    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 50    private static final Cache<CacheKey, KeyPair> CACHE =
 51            CacheBuilder.newBuilder().maximumSize(10).build();
 52    protected final ChannelBinding channelBinding;
 53    private final String gs2Header;
 54    private final String clientNonce;
 55    private final String clientFirstMessageBare;
 56    private byte[] serverSignature = null;
 57    private DowngradeProtection downgradeProtection = null;
 58
 59    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 60        super(account);
 61        this.channelBinding = channelBinding;
 62        if (channelBinding == ChannelBinding.NONE) {
 63            this.gs2Header = "y,,";
 64        } else {
 65            this.gs2Header =
 66                    String.format(
 67                            "p=%s,,",
 68                            CaseFormat.UPPER_UNDERSCORE
 69                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 70                                    .convert(channelBinding.toString()));
 71        }
 72        // This nonce should be different for each authentication attempt.
 73        this.clientNonce = CryptoHelper.random(100);
 74        this.clientFirstMessageBare =
 75                String.format(
 76                        "n=%s,r=%s",
 77                        CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
 78                        this.clientNonce);
 79    }
 80
 81    public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
 82        Preconditions.checkState(
 83                this.state == State.INITIAL, "setting downgrade protection in invalid state");
 84        this.downgradeProtection = downgradeProtection;
 85    }
 86
 87    protected abstract HashFunction getHMac(final byte[] key);
 88
 89    protected abstract HashFunction getDigest();
 90
 91    private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
 92            throws ExecutionException {
 93        final var key = new CacheKey(getMechanism(), password, salt, iterations);
 94        return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
 95    }
 96
 97    private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
 98            throws InvalidKeyException {
 99        final byte[] saltedPassword, serverKey, clientKey;
100        saltedPassword = hi(password.getBytes(), salt, iterations);
101        serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
102        clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
103        return new KeyPair(clientKey, serverKey);
104    }
105
106    @Override
107    public String getMechanism() {
108        return "";
109    }
110
111    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
112        return getHMac(key).hashBytes(input).asBytes();
113    }
114
115    private byte[] digest(final byte[] bytes) {
116        return getDigest().hashBytes(bytes).asBytes();
117    }
118
119    /*
120     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
121     * pseudorandom function (PRF) and with dkLen == output length of
122     * HMAC() == output length of H().
123     */
124    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
125            throws InvalidKeyException {
126        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
127        byte[] out = u.clone();
128        for (int i = 1; i < iterations; i++) {
129            u = hmac(key, u);
130            for (int j = 0; j < u.length; j++) {
131                out[j] ^= u[j];
132            }
133        }
134        return out;
135    }
136
137    @Override
138    public String getClientFirstMessage(final SSLSocket sslSocket) {
139        Preconditions.checkState(
140                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
141        this.state = State.AUTH_TEXT_SENT;
142        final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
143        return BaseEncoding.base64().encode(message);
144    }
145
146    @Override
147    public String getResponse(final String challenge, final SSLSocket socket)
148            throws AuthenticationException {
149        return switch (state) {
150            case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
151            case RESPONSE_SENT -> processServerFinalMessage(challenge);
152            default -> throw new InvalidStateException(state);
153        };
154    }
155
156    private String processServerFirstMessage(final String challenge, final SSLSocket socket)
157            throws AuthenticationException {
158        if (Strings.isNullOrEmpty(challenge)) {
159            throw new AuthenticationException("challenge can not be null");
160        }
161        byte[] serverFirstMessage;
162        try {
163            serverFirstMessage = BaseEncoding.base64().decode(challenge);
164        } catch (final IllegalArgumentException e) {
165            throw new AuthenticationException("Unable to decode server challenge", e);
166        }
167        final Map<String, String> attributes;
168        try {
169            attributes = splitToAttributes(new String(serverFirstMessage));
170        } catch (final IllegalArgumentException e) {
171            throw new AuthenticationException("Duplicate attributes");
172        }
173        if (attributes.containsKey("m")) {
174            /*
175             * RFC 5802:
176             * m: This attribute is reserved for future extensibility.  In this
177             * version of SCRAM, its presence in a client or a server message
178             * MUST cause authentication failure when the attribute is parsed by
179             * the other end.
180             */
181            throw new AuthenticationException("Server sent reserved token: 'm'");
182        }
183        final String i = attributes.get("i");
184        final String s = attributes.get("s");
185        final String nonce = attributes.get("r");
186        final String h = attributes.get("h");
187        if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
188            throw new AuthenticationException("Missing attributes from server first message");
189        }
190        final Integer iterationCount = Ints.tryParse(i);
191
192        if (iterationCount == null || iterationCount < 0) {
193            throw new AuthenticationException("Server did not send iteration count");
194        }
195
196        if (iterationCount < ITERATION_COUNT_MINIMUM) {
197            throw new AuthenticationException(
198                    String.format(
199                            "Weak iteration count. %d instead of %d",
200                            iterationCount, ITERATION_COUNT_MINIMUM));
201        }
202
203        if (!nonce.startsWith(clientNonce)) {
204            throw new AuthenticationException(
205                    "Server nonce does not contain client nonce: " + nonce);
206        }
207
208        final byte[] salt;
209
210        try {
211            salt = BaseEncoding.base64().decode(s);
212        } catch (final IllegalArgumentException e) {
213            throw new AuthenticationException("Invalid salt in server first message");
214        }
215
216        if (h != null && this.downgradeProtection != null) {
217            final String asSeenInFeatures;
218            try {
219                asSeenInFeatures = downgradeProtection.asHString();
220            } catch (final SecurityException e) {
221                throw new AuthenticationException(e);
222            }
223            final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
224            if (!hashed.equals(h)) {
225                throw new AuthenticationException("Mismatch in SSDP");
226            }
227        }
228
229        final byte[] channelBindingData = getChannelBindingData(socket);
230
231        final int gs2Len = this.gs2Header.getBytes().length;
232        final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
233        System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
234        System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
235
236        final String clientFinalMessageWithoutProof =
237                String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
238
239        final var authMessage =
240                Joiner.on(',')
241                        .join(
242                                clientFirstMessageBare,
243                                new String(serverFirstMessage),
244                                clientFinalMessageWithoutProof);
245
246        final KeyPair keys;
247        try {
248            keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
249        } catch (final ExecutionException e) {
250            throw new AuthenticationException("Invalid keys generated");
251        }
252        final byte[] clientSignature;
253        try {
254            serverSignature = hmac(keys.serverKey, authMessage.getBytes());
255            final byte[] storedKey = digest(keys.clientKey);
256
257            clientSignature = hmac(storedKey, authMessage.getBytes());
258
259        } catch (final InvalidKeyException e) {
260            throw new AuthenticationException(e);
261        }
262
263        final byte[] clientProof = new byte[keys.clientKey.length];
264
265        if (clientSignature.length < keys.clientKey.length) {
266            throw new AuthenticationException("client signature was shorter than clientKey");
267        }
268
269        for (int j = 0; j < clientProof.length; j++) {
270            clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
271        }
272
273        final var clientFinalMessage =
274                String.format(
275                        "%s,p=%s",
276                        clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
277        this.state = State.RESPONSE_SENT;
278        return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
279    }
280
281    private Map<String, String> splitToAttributes(final String message) {
282        final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
283        for (final String token : Splitter.on(',').split(message)) {
284            final var tuple = Splitter.on('=').limit(2).splitToList(token);
285            if (tuple.size() == 2) {
286                builder.put(tuple.get(0), tuple.get(1));
287            }
288        }
289        return builder.buildOrThrow();
290    }
291
292    private String processServerFinalMessage(final String challenge)
293            throws AuthenticationException {
294        final String serverFinalMessage;
295        try {
296            serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
297        } catch (final IllegalArgumentException e) {
298            throw new AuthenticationException("Invalid base64 in server final message", e);
299        }
300        final var clientCalculatedServerFinalMessage =
301                String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
302        if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
303            this.state = State.VALID_SERVER_RESPONSE;
304            return "";
305        }
306        throw new AuthenticationException(
307                "Server final message does not match calculated final message");
308    }
309
310    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
311            throws AuthenticationException {
312        if (this.channelBinding == ChannelBinding.NONE) {
313            return new byte[0];
314        }
315        throw new AssertionError("getChannelBindingData needs to be overwritten");
316    }
317
318    private static class CacheKey {
319        private final String algorithm;
320        private final String password;
321        private final byte[] salt;
322        private final int iterations;
323
324        private CacheKey(
325                final String algorithm,
326                final String password,
327                final byte[] salt,
328                final int iterations) {
329            this.algorithm = algorithm;
330            this.password = password;
331            this.salt = salt;
332            this.iterations = iterations;
333        }
334
335        @Override
336        public boolean equals(final Object o) {
337            if (this == o) return true;
338            if (o == null || getClass() != o.getClass()) return false;
339            CacheKey cacheKey = (CacheKey) o;
340            return iterations == cacheKey.iterations
341                    && Objects.equal(algorithm, cacheKey.algorithm)
342                    && Objects.equal(password, cacheKey.password)
343                    && Arrays.equals(salt, cacheKey.salt);
344        }
345
346        @Override
347        public int hashCode() {
348            final int result = Objects.hashCode(algorithm, password, iterations);
349            return 31 * result + Arrays.hashCode(salt);
350        }
351    }
352
353    private static class KeyPair {
354        final byte[] clientKey;
355        final byte[] serverKey;
356
357        KeyPair(final byte[] clientKey, final byte[] serverKey) {
358            this.clientKey = clientKey;
359            this.serverKey = serverKey;
360        }
361    }
362}