ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
import com.google.common.io.BaseEncoding;
import com.google.common.primitives.Ints;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
 23
 24public abstract class ScramMechanism extends SaslMechanism {
 25
 26    public static final SecretKey EMPTY_KEY =
 27            new SecretKey() {
 28                @Override
 29                public String getAlgorithm() {
 30                    return "HMAC";
 31                }
 32
 33                @Override
 34                public String getFormat() {
 35                    return "RAW";
 36                }
 37
 38                @Override
 39                public byte[] getEncoded() {
 40                    return new byte[0];
 41                }
 42            };
 43
 44    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 45    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 46    private static final Cache<CacheKey, KeyPair> CACHE =
 47            CacheBuilder.newBuilder().maximumSize(10).build();
 48    protected final ChannelBinding channelBinding;
 49    private final String gs2Header;
 50    private final String clientNonce;
 51    private final String clientFirstMessageBare;
 52    private byte[] serverSignature = null;
 53    private DowngradeProtection downgradeProtection = null;
 54
 55    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 56        super(account);
 57        this.channelBinding = channelBinding;
 58        if (channelBinding == ChannelBinding.NONE) {
 59            this.gs2Header = "y,,";
 60        } else {
 61            this.gs2Header =
 62                    String.format(
 63                            "p=%s,,",
 64                            CaseFormat.UPPER_UNDERSCORE
 65                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 66                                    .convert(channelBinding.toString()));
 67        }
 68        // This nonce should be different for each authentication attempt.
 69        this.clientNonce = CryptoHelper.random(100);
 70        this.clientFirstMessageBare =
 71                String.format(
 72                        "n=%s,r=%s",
 73                        CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
 74                        this.clientNonce);
 75    }
 76
 77    public void setDowngradeProtection(final DowngradeProtection downgradeProtection) {
 78        Preconditions.checkState(
 79                this.state == State.INITIAL, "setting downgrade protection in invalid state");
 80        this.downgradeProtection = downgradeProtection;
 81    }
 82
 83    protected abstract HashFunction getHMac(final byte[] key);
 84
 85    protected abstract HashFunction getDigest();
 86
 87    private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
 88            throws ExecutionException {
 89        final var key = new CacheKey(getMechanism(), password, salt, iterations);
 90        return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
 91    }
 92
 93    private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
 94            throws InvalidKeyException {
 95        final byte[] saltedPassword, serverKey, clientKey;
 96        saltedPassword = hi(password.getBytes(), salt, iterations);
 97        serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 98        clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 99        return new KeyPair(clientKey, serverKey);
100    }
101
102    @Override
103    public String getMechanism() {
104        return "";
105    }
106
107    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
108        return getHMac(key).hashBytes(input).asBytes();
109    }
110
111    private byte[] digest(final byte[] bytes) {
112        return getDigest().hashBytes(bytes).asBytes();
113    }
114
115    /*
116     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
117     * pseudorandom function (PRF) and with dkLen == output length of
118     * HMAC() == output length of H().
119     */
120    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
121            throws InvalidKeyException {
122        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
123        byte[] out = u.clone();
124        for (int i = 1; i < iterations; i++) {
125            u = hmac(key, u);
126            for (int j = 0; j < u.length; j++) {
127                out[j] ^= u[j];
128            }
129        }
130        return out;
131    }
132
133    @Override
134    public String getClientFirstMessage(final SSLSocket sslSocket) {
135        Preconditions.checkState(
136                this.state == State.INITIAL, "Calling getClientFirstMessage from invalid state");
137        this.state = State.AUTH_TEXT_SENT;
138        final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
139        return BaseEncoding.base64().encode(message);
140    }
141
142    @Override
143    public String getResponse(final String challenge, final SSLSocket socket)
144            throws AuthenticationException {
145        return switch (state) {
146            case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
147            case RESPONSE_SENT -> processServerFinalMessage(challenge);
148            default -> throw new InvalidStateException(state);
149        };
150    }
151
152    private String processServerFirstMessage(final String challenge, final SSLSocket socket)
153            throws AuthenticationException {
154        if (Strings.isNullOrEmpty(challenge)) {
155            throw new AuthenticationException("challenge can not be null");
156        }
157        byte[] serverFirstMessage;
158        try {
159            serverFirstMessage = BaseEncoding.base64().decode(challenge);
160        } catch (final IllegalArgumentException e) {
161            throw new AuthenticationException("Unable to decode server challenge", e);
162        }
163        final Map<String, String> attributes;
164        try {
165            attributes = splitToAttributes(new String(serverFirstMessage));
166        } catch (final IllegalArgumentException e) {
167            throw new AuthenticationException("Duplicate attributes");
168        }
169        if (attributes.containsKey("m")) {
170            /*
171             * RFC 5802:
172             * m: This attribute is reserved for future extensibility.  In this
173             * version of SCRAM, its presence in a client or a server message
174             * MUST cause authentication failure when the attribute is parsed by
175             * the other end.
176             */
177            throw new AuthenticationException("Server sent reserved token: 'm'");
178        }
179        final String i = attributes.get("i");
180        final String s = attributes.get("s");
181        final String nonce = attributes.get("r");
182        final String h = attributes.get("h");
183        if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
184            throw new AuthenticationException("Missing attributes from server first message");
185        }
186        final Integer iterationCount = Ints.tryParse(i);
187
188        if (iterationCount == null || iterationCount < 0) {
189            throw new AuthenticationException("Server did not send iteration count");
190        }
191        if (!nonce.startsWith(clientNonce)) {
192            throw new AuthenticationException(
193                    "Server nonce does not contain client nonce: " + nonce);
194        }
195
196        final byte[] salt;
197
198        try {
199            salt = BaseEncoding.base64().decode(s);
200        } catch (final IllegalArgumentException e) {
201            throw new AuthenticationException("Invalid salt in server first message");
202        }
203
204        if (h != null && this.downgradeProtection != null) {
205            final String asSeenInFeatures;
206            try {
207                asSeenInFeatures = downgradeProtection.asHString();
208            } catch (final SecurityException e) {
209                throw new AuthenticationException(e);
210            }
211            final var hashed = BaseEncoding.base64().encode(digest(asSeenInFeatures.getBytes()));
212            if (!hashed.equals(h)) {
213                throw new AuthenticationException("Mismatch in SSDP");
214            }
215        }
216
217        final byte[] channelBindingData = getChannelBindingData(socket);
218
219        final int gs2Len = this.gs2Header.getBytes().length;
220        final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
221        System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
222        System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
223
224        final String clientFinalMessageWithoutProof =
225                String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
226
227        final var authMessage =
228                Joiner.on(',')
229                        .join(
230                                clientFirstMessageBare,
231                                new String(serverFirstMessage),
232                                clientFinalMessageWithoutProof);
233
234        final KeyPair keys;
235        try {
236            keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
237        } catch (final ExecutionException e) {
238            throw new AuthenticationException("Invalid keys generated");
239        }
240        final byte[] clientSignature;
241        try {
242            serverSignature = hmac(keys.serverKey, authMessage.getBytes());
243            final byte[] storedKey = digest(keys.clientKey);
244
245            clientSignature = hmac(storedKey, authMessage.getBytes());
246
247        } catch (final InvalidKeyException e) {
248            throw new AuthenticationException(e);
249        }
250
251        final byte[] clientProof = new byte[keys.clientKey.length];
252
253        if (clientSignature.length < keys.clientKey.length) {
254            throw new AuthenticationException("client signature was shorter than clientKey");
255        }
256
257        for (int j = 0; j < clientProof.length; j++) {
258            clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
259        }
260
261        final var clientFinalMessage =
262                String.format(
263                        "%s,p=%s",
264                        clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
265        this.state = State.RESPONSE_SENT;
266        return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
267    }
268
269    private Map<String, String> splitToAttributes(final String message) {
270        final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
271        for (final String token : Splitter.on(',').split(message)) {
272            final var tuple = Splitter.on('=').limit(2).splitToList(token);
273            if (tuple.size() == 2) {
274                builder.put(tuple.get(0), tuple.get(1));
275            }
276        }
277        return builder.buildOrThrow();
278    }
279
280    private String processServerFinalMessage(final String challenge)
281            throws AuthenticationException {
282        final String serverFinalMessage;
283        try {
284            serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
285        } catch (final IllegalArgumentException e) {
286            throw new AuthenticationException("Invalid base64 in server final message", e);
287        }
288        final var clientCalculatedServerFinalMessage =
289                String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
290        if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
291            this.state = State.VALID_SERVER_RESPONSE;
292            return "";
293        }
294        throw new AuthenticationException(
295                "Server final message does not match calculated final message");
296    }
297
298    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
299            throws AuthenticationException {
300        if (this.channelBinding == ChannelBinding.NONE) {
301            return new byte[0];
302        }
303        throw new AssertionError("getChannelBindingData needs to be overwritten");
304    }
305
306    private static class CacheKey {
307        private final String algorithm;
308        private final String password;
309        private final byte[] salt;
310        private final int iterations;
311
312        private CacheKey(
313                final String algorithm,
314                final String password,
315                final byte[] salt,
316                final int iterations) {
317            this.algorithm = algorithm;
318            this.password = password;
319            this.salt = salt;
320            this.iterations = iterations;
321        }
322
323        @Override
324        public boolean equals(final Object o) {
325            if (this == o) return true;
326            if (o == null || getClass() != o.getClass()) return false;
327            CacheKey cacheKey = (CacheKey) o;
328            return iterations == cacheKey.iterations
329                    && Objects.equal(algorithm, cacheKey.algorithm)
330                    && Objects.equal(password, cacheKey.password)
331                    && Arrays.equals(salt, cacheKey.salt);
332        }
333
334        @Override
335        public int hashCode() {
336            final int result = Objects.hashCode(algorithm, password, iterations);
337            return 31 * result + Arrays.hashCode(salt);
338        }
339    }
340
341    private static class KeyPair {
342        final byte[] clientKey;
343        final byte[] serverKey;
344
345        KeyPair(final byte[] clientKey, final byte[] serverKey) {
346            this.clientKey = clientKey;
347            this.serverKey = serverKey;
348        }
349    }
350}