ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
 23
 24public abstract class ScramMechanism extends SaslMechanism {
 25
 26    public static final SecretKey EMPTY_KEY =
 27            new SecretKey() {
 28                @Override
 29                public String getAlgorithm() {
 30                    return "HMAC";
 31                }
 32
 33                @Override
 34                public String getFormat() {
 35                    return "RAW";
 36                }
 37
 38                @Override
 39                public byte[] getEncoded() {
 40                    return new byte[0];
 41                }
 42            };
 43
 44    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 45    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 46    private static final Cache<CacheKey, KeyPair> CACHE =
 47            CacheBuilder.newBuilder().maximumSize(10).build();
 48    protected final ChannelBinding channelBinding;
 49    private final String gs2Header;
 50    private final String clientNonce;
 51    protected State state = State.INITIAL;
 52    private final String clientFirstMessageBare;
 53    private byte[] serverSignature = null;
 54    private DowngradeProtection downgradeProtection = null;
 55
 56    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 57        super(account);
 58        this.channelBinding = channelBinding;
 59        if (channelBinding == ChannelBinding.NONE) {
 60            // TODO this needs to be changed to "y,," for the scram internal down grade protection
 61            // but we might risk compatibility issues if the server supports a binding that we don’t
 62            // support
 63            this.gs2Header = "n,,";
 64        } else {
 65            this.gs2Header =
 66                    String.format(
 67                            "p=%s,,",
 68                            CaseFormat.UPPER_UNDERSCORE
 69                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 70                                    .convert(channelBinding.toString()));
 71        }
 72        // This nonce should be different for each authentication attempt.
 73        this.clientNonce = CryptoHelper.random(100);
 74        this.clientFirstMessageBare =
 75                String.format(
 76                        "n=%s,r=%s",
 77                        CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
 78                        this.clientNonce);
 79    }
 80
 81    public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
 82        Preconditions.checkState(
 83                this.state == State.INITIAL, "setting downgrade protection in invalid state");
 84        this.downgradeProtection = downgradeProtection;
 85    }
 86
 87    protected abstract HashFunction getHMac(final byte[] key);
 88
 89    protected abstract HashFunction getDigest();
 90
 91    private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
 92            throws ExecutionException {
 93        final var key = new CacheKey(getMechanism(), password, salt, iterations);
 94        return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
 95    }
 96
 97    private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
 98            throws InvalidKeyException {
 99        final byte[] saltedPassword, serverKey, clientKey;
100        saltedPassword = hi(password.getBytes(), salt, iterations);
101        serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
102        clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
103        return new KeyPair(clientKey, serverKey);
104    }
105
106    @Override
107    public String getMechanism() {
108        return "";
109    }
110
111    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
112        return getHMac(key).hashBytes(input).asBytes();
113    }
114
115    private byte[] digest(final byte[] bytes) {
116        return getDigest().hashBytes(bytes).asBytes();
117    }
118
119    /*
120     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
121     * pseudorandom function (PRF) and with dkLen == output length of
122     * HMAC() == output length of H().
123     */
124    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
125            throws InvalidKeyException {
126        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
127        byte[] out = u.clone();
128        for (int i = 1; i < iterations; i++) {
129            u = hmac(key, u);
130            for (int j = 0; j < u.length; j++) {
131                out[j] ^= u[j];
132            }
133        }
134        return out;
135    }
136
137    @Override
138    public String getClientFirstMessage(final SSLSocket sslSocket) {
139        Preconditions.checkState(
140                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
141        this.state = State.AUTH_TEXT_SENT;
142        final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
143        return BaseEncoding.base64().encode(message);
144    }
145
146    @Override
147    public String getResponse(final String challenge, final SSLSocket socket)
148            throws AuthenticationException {
149        return switch (state) {
150            case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
151            case RESPONSE_SENT -> processServerFinalMessage(challenge);
152            default -> throw new InvalidStateException(state);
153        };
154    }
155
156    private String processServerFirstMessage(final String challenge, final SSLSocket socket)
157            throws AuthenticationException {
158        if (Strings.isNullOrEmpty(challenge)) {
159            throw new AuthenticationException("challenge can not be null");
160        }
161        byte[] serverFirstMessage;
162        try {
163            serverFirstMessage = BaseEncoding.base64().decode(challenge);
164        } catch (final IllegalArgumentException e) {
165            throw new AuthenticationException("Unable to decode server challenge", e);
166        }
167        final Map<String, String> attributes;
168        try {
169            attributes = splitToAttributes(new String(serverFirstMessage));
170        } catch (final IllegalArgumentException e) {
171            throw new AuthenticationException("Duplicate attributes");
172        }
173        if (attributes.containsKey("m")) {
174            /*
175             * RFC 5802:
176             * m: This attribute is reserved for future extensibility.  In this
177             * version of SCRAM, its presence in a client or a server message
178             * MUST cause authentication failure when the attribute is parsed by
179             * the other end.
180             */
181            throw new AuthenticationException("Server sent reserved token: 'm'");
182        }
183        final String i = attributes.get("i");
184        final String s = attributes.get("s");
185        final String nonce = attributes.get("r");
186        final String h = attributes.get("h");
187        if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
188            throw new AuthenticationException("Missing attributes from server first message");
189        }
190        final Integer iterationCount = Ints.tryParse(i);
191
192        if (iterationCount == null || iterationCount < 0) {
193            throw new AuthenticationException("Server did not send iteration count");
194        }
195        if (!nonce.startsWith(clientNonce)) {
196            throw new AuthenticationException(
197                    "Server nonce does not contain client nonce: " + nonce);
198        }
199
200        final byte[] salt;
201
202        try {
203            salt = BaseEncoding.base64().decode(s);
204        } catch (final IllegalArgumentException e) {
205            throw new AuthenticationException("Invalid salt in server first message");
206        }
207
208        if (h != null && this.downgradeProtection != null) {
209            final String asSeenInFeatures;
210            try {
211                asSeenInFeatures = downgradeProtection.asHString();
212            } catch (final SecurityException e) {
213                throw new AuthenticationException(e);
214            }
215            final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
216            if (!hashed.equals(h)) {
217                throw new AuthenticationException("Mismatch in SSDP");
218            }
219        }
220
221        final byte[] channelBindingData = getChannelBindingData(socket);
222
223        final int gs2Len = this.gs2Header.getBytes().length;
224        final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
225        System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
226        System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
227
228        final String clientFinalMessageWithoutProof =
229                String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
230
231        final var authMessage =
232                Joiner.on(',')
233                        .join(
234                                clientFirstMessageBare,
235                                new String(serverFirstMessage),
236                                clientFinalMessageWithoutProof);
237
238        final KeyPair keys;
239        try {
240            keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
241        } catch (final ExecutionException e) {
242            throw new AuthenticationException("Invalid keys generated");
243        }
244        final byte[] clientSignature;
245        try {
246            serverSignature = hmac(keys.serverKey, authMessage.getBytes());
247            final byte[] storedKey = digest(keys.clientKey);
248
249            clientSignature = hmac(storedKey, authMessage.getBytes());
250
251        } catch (final InvalidKeyException e) {
252            throw new AuthenticationException(e);
253        }
254
255        final byte[] clientProof = new byte[keys.clientKey.length];
256
257        if (clientSignature.length < keys.clientKey.length) {
258            throw new AuthenticationException("client signature was shorter than clientKey");
259        }
260
261        for (int j = 0; j < clientProof.length; j++) {
262            clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
263        }
264
265        final var clientFinalMessage =
266                String.format(
267                        "%s,p=%s",
268                        clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
269        this.state = State.RESPONSE_SENT;
270        return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
271    }
272
273    private Map<String, String> splitToAttributes(final String message) {
274        final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
275        for (final String token : Splitter.on(',').split(message)) {
276            final var tuple = Splitter.on('=').limit(2).splitToList(token);
277            if (tuple.size() == 2) {
278                builder.put(tuple.get(0), tuple.get(1));
279            }
280        }
281        return builder.buildOrThrow();
282    }
283
284    private String processServerFinalMessage(final String challenge)
285            throws AuthenticationException {
286        final String serverFinalMessage;
287        try {
288            serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
289        } catch (final IllegalArgumentException e) {
290            throw new AuthenticationException("Invalid base64 in server final message", e);
291        }
292        final var clientCalculatedServerFinalMessage =
293                String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
294        if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
295            this.state = State.VALID_SERVER_RESPONSE;
296            return "";
297        }
298        throw new AuthenticationException(
299                "Server final message does not match calculated final message");
300    }
301
302    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
303            throws AuthenticationException {
304        if (this.channelBinding == ChannelBinding.NONE) {
305            return new byte[0];
306        }
307        throw new AssertionError("getChannelBindingData needs to be overwritten");
308    }
309
310    private static class CacheKey {
311        private final String algorithm;
312        private final String password;
313        private final byte[] salt;
314        private final int iterations;
315
316        private CacheKey(
317                final String algorithm,
318                final String password,
319                final byte[] salt,
320                final int iterations) {
321            this.algorithm = algorithm;
322            this.password = password;
323            this.salt = salt;
324            this.iterations = iterations;
325        }
326
327        @Override
328        public boolean equals(final Object o) {
329            if (this == o) return true;
330            if (o == null || getClass() != o.getClass()) return false;
331            CacheKey cacheKey = (CacheKey) o;
332            return iterations == cacheKey.iterations
333                    && Objects.equal(algorithm, cacheKey.algorithm)
334                    && Objects.equal(password, cacheKey.password)
335                    && Arrays.equals(salt, cacheKey.salt);
336        }
337
338        @Override
339        public int hashCode() {
340            final int result = Objects.hashCode(algorithm, password, iterations);
341            return 31 * result + Arrays.hashCode(salt);
342        }
343    }
344
345    private static class KeyPair {
346        final byte[] clientKey;
347        final byte[] serverKey;
348
349        KeyPair(final byte[] clientKey, final byte[] serverKey) {
350            this.clientKey = clientKey;
351            this.serverKey = serverKey;
352        }
353    }
354}