ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.util.Base64;
  4
  5import com.google.common.base.CaseFormat;
  6import com.google.common.base.Objects;
  7import com.google.common.cache.Cache;
  8import com.google.common.cache.CacheBuilder;
  9import com.google.common.hash.HashFunction;
 10
 11import java.nio.charset.Charset;
 12import java.security.InvalidKeyException;
 13import java.util.concurrent.ExecutionException;
 14
 15import javax.crypto.SecretKey;
 16import javax.net.ssl.SSLSocket;
 17
 18import eu.siacs.conversations.entities.Account;
 19import eu.siacs.conversations.utils.CryptoHelper;
 20
 21abstract class ScramMechanism extends SaslMechanism {
 22
 23    public static final SecretKey EMPTY_KEY =
 24            new SecretKey() {
 25                @Override
 26                public String getAlgorithm() {
 27                    return "HMAC";
 28                }
 29
 30                @Override
 31                public String getFormat() {
 32                    return "RAW";
 33                }
 34
 35                @Override
 36                public byte[] getEncoded() {
 37                    return new byte[0];
 38                }
 39            };
 40
 41    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 42    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 43    private static final Cache<CacheKey, KeyPair> CACHE =
 44            CacheBuilder.newBuilder().maximumSize(10).build();
 45    protected final ChannelBinding channelBinding;
 46    private final String gs2Header;
 47    private final String clientNonce;
 48    protected State state = State.INITIAL;
 49    private String clientFirstMessageBare;
 50    private byte[] serverSignature = null;
 51
 52    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 53        super(account);
 54        this.channelBinding = channelBinding;
 55        if (channelBinding == ChannelBinding.NONE) {
 56            // TODO this needs to be changed to "y,," for the scram internal down grade protection
 57            // but we might risk compatibility issues if the server supports a binding that we don’t
 58            // support
 59            this.gs2Header = "n,,";
 60        } else {
 61            this.gs2Header =
 62                    String.format(
 63                            "p=%s,,",
 64                            CaseFormat.UPPER_UNDERSCORE
 65                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 66                                    .convert(channelBinding.toString()));
 67        }
 68        // This nonce should be different for each authentication attempt.
 69        this.clientNonce = CryptoHelper.random(100);
 70        clientFirstMessageBare = "";
 71    }
 72
 73    protected abstract HashFunction getHMac(final byte[] key);
 74
 75    protected abstract HashFunction getDigest();
 76
 77    private KeyPair getKeyPair(final String password, final String salt, final int iterations)
 78            throws ExecutionException {
 79        return CACHE.get(
 80                new CacheKey(getMechanism(), password, salt, iterations),
 81                () -> {
 82                    final byte[] saltedPassword, serverKey, clientKey;
 83                    saltedPassword =
 84                            hi(
 85                                    password.getBytes(),
 86                                    Base64.decode(salt, Base64.DEFAULT),
 87                                    iterations);
 88                    serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 89                    clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 90                    return new KeyPair(clientKey, serverKey);
 91                });
 92    }
 93
 94    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
 95        return getHMac(key).hashBytes(input).asBytes();
 96    }
 97
 98    private byte[] digest(final byte[] bytes) {
 99        return getDigest().hashBytes(bytes).asBytes();
100    }
101
102    /*
103     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
104     * pseudorandom function (PRF) and with dkLen == output length of
105     * HMAC() == output length of H().
106     */
107    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
108            throws InvalidKeyException {
109        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
110        byte[] out = u.clone();
111        for (int i = 1; i < iterations; i++) {
112            u = hmac(key, u);
113            for (int j = 0; j < u.length; j++) {
114                out[j] ^= u[j];
115            }
116        }
117        return out;
118    }
119
120    @Override
121    public String getClientFirstMessage(final SSLSocket sslSocket) {
122        if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
123            clientFirstMessageBare =
124                    "n="
125                            + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
126                            + ",r="
127                            + this.clientNonce;
128            state = State.AUTH_TEXT_SENT;
129        }
130        return Base64.encodeToString(
131                (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
132                Base64.NO_WRAP);
133    }
134
135    @Override
136    public String getResponse(final String challenge, final SSLSocket socket)
137            throws AuthenticationException {
138        switch (state) {
139            case AUTH_TEXT_SENT:
140                if (challenge == null) {
141                    throw new AuthenticationException("challenge can not be null");
142                }
143                byte[] serverFirstMessage;
144                try {
145                    serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
146                } catch (IllegalArgumentException e) {
147                    throw new AuthenticationException("Unable to decode server challenge", e);
148                }
149                final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
150                String nonce = "";
151                int iterationCount = -1;
152                String salt = "";
153                for (final String token : tokenizer) {
154                    if (token.length() > 1 && token.charAt(1) == '=') {
155                        switch (token.charAt(0)) {
156                            case 'i':
157                                try {
158                                    iterationCount = Integer.parseInt(token.substring(2));
159                                } catch (final NumberFormatException e) {
160                                    throw new AuthenticationException(e);
161                                }
162                                break;
163                            case 's':
164                                salt = token.substring(2);
165                                break;
166                            case 'r':
167                                nonce = token.substring(2);
168                                break;
169                            case 'm':
170                                /*
171                                 * RFC 5802:
172                                 * m: This attribute is reserved for future extensibility.  In this
173                                 * version of SCRAM, its presence in a client or a server message
174                                 * MUST cause authentication failure when the attribute is parsed by
175                                 * the other end.
176                                 */
177                                throw new AuthenticationException(
178                                        "Server sent reserved token: `m'");
179                        }
180                    }
181                }
182
183                if (iterationCount < 0) {
184                    throw new AuthenticationException("Server did not send iteration count");
185                }
186                if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
187                    throw new AuthenticationException(
188                            "Server nonce does not contain client nonce: " + nonce);
189                }
190                if (salt.isEmpty()) {
191                    throw new AuthenticationException("Server sent empty salt");
192                }
193
194                final byte[] channelBindingData = getChannelBindingData(socket);
195
196                final int gs2Len = this.gs2Header.getBytes().length;
197                final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
198                System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
199                System.arraycopy(
200                        channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
201
202                final String clientFinalMessageWithoutProof =
203                        "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
204
205                final byte[] authMessage =
206                        (clientFirstMessageBare
207                                        + ','
208                                        + new String(serverFirstMessage)
209                                        + ','
210                                        + clientFinalMessageWithoutProof)
211                                .getBytes();
212
213                final KeyPair keys;
214                try {
215                    keys =
216                            getKeyPair(
217                                    CryptoHelper.saslPrep(account.getPassword()),
218                                    salt,
219                                    iterationCount);
220                } catch (ExecutionException e) {
221                    throw new AuthenticationException("Invalid keys generated");
222                }
223                final byte[] clientSignature;
224                try {
225                    serverSignature = hmac(keys.serverKey, authMessage);
226                    final byte[] storedKey = digest(keys.clientKey);
227
228                    clientSignature = hmac(storedKey, authMessage);
229
230                } catch (final InvalidKeyException e) {
231                    throw new AuthenticationException(e);
232                }
233
234                final byte[] clientProof = new byte[keys.clientKey.length];
235
236                if (clientSignature.length < keys.clientKey.length) {
237                    throw new AuthenticationException(
238                            "client signature was shorter than clientKey");
239                }
240
241                for (int i = 0; i < clientProof.length; i++) {
242                    clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
243                }
244
245                final String clientFinalMessage =
246                        clientFinalMessageWithoutProof
247                                + ",p="
248                                + Base64.encodeToString(clientProof, Base64.NO_WRAP);
249                state = State.RESPONSE_SENT;
250                return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
251            case RESPONSE_SENT:
252                try {
253                    final String clientCalculatedServerFinalMessage =
254                            "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
255                    if (!clientCalculatedServerFinalMessage.equals(
256                            new String(Base64.decode(challenge, Base64.DEFAULT)))) {
257                        throw new Exception();
258                    }
259                    state = State.VALID_SERVER_RESPONSE;
260                    return "";
261                } catch (Exception e) {
262                    throw new AuthenticationException(
263                            "Server final message does not match calculated final message");
264                }
265            default:
266                throw new InvalidStateException(state);
267        }
268    }
269
270    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
271            throws AuthenticationException {
272        if (this.channelBinding == ChannelBinding.NONE) {
273            return new byte[0];
274        }
275        throw new AssertionError("getChannelBindingData needs to be overwritten");
276    }
277
278    private static class CacheKey {
279        final String algorithm;
280        final String password;
281        final String salt;
282        final int iterations;
283
284        private CacheKey(String algorithm, String password, String salt, int iterations) {
285            this.algorithm = algorithm;
286            this.password = password;
287            this.salt = salt;
288            this.iterations = iterations;
289        }
290
291        @Override
292        public boolean equals(Object o) {
293            if (this == o) return true;
294            if (o == null || getClass() != o.getClass()) return false;
295            CacheKey cacheKey = (CacheKey) o;
296            return iterations == cacheKey.iterations
297                    && Objects.equal(algorithm, cacheKey.algorithm)
298                    && Objects.equal(password, cacheKey.password)
299                    && Objects.equal(salt, cacheKey.salt);
300        }
301
302        @Override
303        public int hashCode() {
304            return Objects.hashCode(algorithm, password, salt, iterations);
305        }
306    }
307
308    private static class KeyPair {
309        final byte[] clientKey;
310        final byte[] serverKey;
311
312        KeyPair(final byte[] clientKey, final byte[] serverKey) {
313            this.clientKey = clientKey;
314            this.serverKey = serverKey;
315        }
316    }
317}