ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.util.Base64;
  4
  5import com.google.common.base.CaseFormat;
  6import com.google.common.base.Objects;
  7import com.google.common.cache.Cache;
  8import com.google.common.cache.CacheBuilder;
  9
 10import org.bouncycastle.crypto.Digest;
 11import org.bouncycastle.crypto.macs.HMac;
 12import org.bouncycastle.crypto.params.KeyParameter;
 13
 14import java.nio.charset.Charset;
 15import java.security.InvalidKeyException;
 16import java.util.concurrent.ExecutionException;
 17
 18import javax.net.ssl.SSLSocket;
 19
 20import eu.siacs.conversations.entities.Account;
 21import eu.siacs.conversations.utils.CryptoHelper;
 22
 23abstract class ScramMechanism extends SaslMechanism {
 24
 25    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 26    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 27    private static final Cache<CacheKey, KeyPair> CACHE =
 28            CacheBuilder.newBuilder().maximumSize(10).build();
 29    protected final ChannelBinding channelBinding;
 30    private final String gs2Header;
 31    private final String clientNonce;
 32    protected State state = State.INITIAL;
 33    private String clientFirstMessageBare;
 34    private byte[] serverSignature = null;
 35
 36    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 37        super(account);
 38        this.channelBinding = channelBinding;
 39        if (channelBinding == ChannelBinding.NONE) {
 40            // TODO this needs to be changed to "y,," for the scram internal down grade protection
 41            // but we might risk compatibility issues if the server supports a binding that we don’t
 42            // support
 43            this.gs2Header = "n,,";
 44        } else {
 45            this.gs2Header =
 46                    String.format(
 47                            "p=%s,,",
 48                            CaseFormat.UPPER_UNDERSCORE
 49                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 50                                    .convert(channelBinding.toString()));
 51        }
 52        // This nonce should be different for each authentication attempt.
 53        this.clientNonce = CryptoHelper.random(100);
 54        clientFirstMessageBare = "";
 55    }
 56
 57    protected abstract HMac getHMAC();
 58
 59    protected abstract Digest getDigest();
 60
 61    private KeyPair getKeyPair(final String password, final String salt, final int iterations)
 62            throws ExecutionException {
 63        return CACHE.get(
 64                new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations),
 65                () -> {
 66                    final byte[] saltedPassword, serverKey, clientKey;
 67                    saltedPassword =
 68                            hi(
 69                                    password.getBytes(),
 70                                    Base64.decode(salt, Base64.DEFAULT),
 71                                    iterations);
 72                    serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 73                    clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 74                    return new KeyPair(clientKey, serverKey);
 75                });
 76    }
 77
 78    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
 79        final HMac hMac = getHMAC();
 80        hMac.init(new KeyParameter(key));
 81        hMac.update(input, 0, input.length);
 82        final byte[] out = new byte[hMac.getMacSize()];
 83        hMac.doFinal(out, 0);
 84        return out;
 85    }
 86
 87    public byte[] digest(final byte[] bytes) {
 88        final Digest digest = getDigest();
 89        digest.reset();
 90        digest.update(bytes, 0, bytes.length);
 91        final byte[] out = new byte[digest.getDigestSize()];
 92        digest.doFinal(out, 0);
 93        return out;
 94    }
 95
 96    /*
 97     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
 98     * pseudorandom function (PRF) and with dkLen == output length of
 99     * HMAC() == output length of H().
100     */
101    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
102            throws InvalidKeyException {
103        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
104        byte[] out = u.clone();
105        for (int i = 1; i < iterations; i++) {
106            u = hmac(key, u);
107            for (int j = 0; j < u.length; j++) {
108                out[j] ^= u[j];
109            }
110        }
111        return out;
112    }
113
114    @Override
115    public String getClientFirstMessage(final SSLSocket sslSocket) {
116        if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
117            clientFirstMessageBare =
118                    "n="
119                            + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
120                            + ",r="
121                            + this.clientNonce;
122            state = State.AUTH_TEXT_SENT;
123        }
124        return Base64.encodeToString(
125                (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
126                Base64.NO_WRAP);
127    }
128
129    @Override
130    public String getResponse(final String challenge, final SSLSocket socket)
131            throws AuthenticationException {
132        switch (state) {
133            case AUTH_TEXT_SENT:
134                if (challenge == null) {
135                    throw new AuthenticationException("challenge can not be null");
136                }
137                byte[] serverFirstMessage;
138                try {
139                    serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
140                } catch (IllegalArgumentException e) {
141                    throw new AuthenticationException("Unable to decode server challenge", e);
142                }
143                final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
144                String nonce = "";
145                int iterationCount = -1;
146                String salt = "";
147                for (final String token : tokenizer) {
148                    if (token.charAt(1) == '=') {
149                        switch (token.charAt(0)) {
150                            case 'i':
151                                try {
152                                    iterationCount = Integer.parseInt(token.substring(2));
153                                } catch (final NumberFormatException e) {
154                                    throw new AuthenticationException(e);
155                                }
156                                break;
157                            case 's':
158                                salt = token.substring(2);
159                                break;
160                            case 'r':
161                                nonce = token.substring(2);
162                                break;
163                            case 'm':
164                                /*
165                                 * RFC 5802:
166                                 * m: This attribute is reserved for future extensibility.  In this
167                                 * version of SCRAM, its presence in a client or a server message
168                                 * MUST cause authentication failure when the attribute is parsed by
169                                 * the other end.
170                                 */
171                                throw new AuthenticationException(
172                                        "Server sent reserved token: `m'");
173                        }
174                    }
175                }
176
177                if (iterationCount < 0) {
178                    throw new AuthenticationException("Server did not send iteration count");
179                }
180                if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
181                    throw new AuthenticationException(
182                            "Server nonce does not contain client nonce: " + nonce);
183                }
184                if (salt.isEmpty()) {
185                    throw new AuthenticationException("Server sent empty salt");
186                }
187
188                final byte[] channelBindingData = getChannelBindingData(socket);
189
190                final int gs2Len = this.gs2Header.getBytes().length;
191                final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
192                System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
193                System.arraycopy(
194                        channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
195
196                final String clientFinalMessageWithoutProof =
197                        "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
198
199                final byte[] authMessage =
200                        (clientFirstMessageBare
201                                        + ','
202                                        + new String(serverFirstMessage)
203                                        + ','
204                                        + clientFinalMessageWithoutProof)
205                                .getBytes();
206
207                final KeyPair keys;
208                try {
209                    keys =
210                            getKeyPair(
211                                    CryptoHelper.saslPrep(account.getPassword()),
212                                    salt,
213                                    iterationCount);
214                } catch (ExecutionException e) {
215                    throw new AuthenticationException("Invalid keys generated");
216                }
217                final byte[] clientSignature;
218                try {
219                    serverSignature = hmac(keys.serverKey, authMessage);
220                    final byte[] storedKey = digest(keys.clientKey);
221
222                    clientSignature = hmac(storedKey, authMessage);
223
224                } catch (final InvalidKeyException e) {
225                    throw new AuthenticationException(e);
226                }
227
228                final byte[] clientProof = new byte[keys.clientKey.length];
229
230                if (clientSignature.length < keys.clientKey.length) {
231                    throw new AuthenticationException(
232                            "client signature was shorter than clientKey");
233                }
234
235                for (int i = 0; i < clientProof.length; i++) {
236                    clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
237                }
238
239                final String clientFinalMessage =
240                        clientFinalMessageWithoutProof
241                                + ",p="
242                                + Base64.encodeToString(clientProof, Base64.NO_WRAP);
243                state = State.RESPONSE_SENT;
244                return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
245            case RESPONSE_SENT:
246                try {
247                    final String clientCalculatedServerFinalMessage =
248                            "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
249                    if (!clientCalculatedServerFinalMessage.equals(
250                            new String(Base64.decode(challenge, Base64.DEFAULT)))) {
251                        throw new Exception();
252                    }
253                    state = State.VALID_SERVER_RESPONSE;
254                    return "";
255                } catch (Exception e) {
256                    throw new AuthenticationException(
257                            "Server final message does not match calculated final message");
258                }
259            default:
260                throw new InvalidStateException(state);
261        }
262    }
263
264    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
265            throws AuthenticationException {
266        if (this.channelBinding == ChannelBinding.NONE) {
267            return new byte[0];
268        }
269        throw new AssertionError("getChannelBindingData needs to be overwritten");
270    }
271
272    private static class CacheKey {
273        final String algorithm;
274        final String password;
275        final String salt;
276        final int iterations;
277
278        private CacheKey(String algorithm, String password, String salt, int iterations) {
279            this.algorithm = algorithm;
280            this.password = password;
281            this.salt = salt;
282            this.iterations = iterations;
283        }
284
285        @Override
286        public boolean equals(Object o) {
287            if (this == o) return true;
288            if (o == null || getClass() != o.getClass()) return false;
289            CacheKey cacheKey = (CacheKey) o;
290            return iterations == cacheKey.iterations
291                    && Objects.equal(algorithm, cacheKey.algorithm)
292                    && Objects.equal(password, cacheKey.password)
293                    && Objects.equal(salt, cacheKey.salt);
294        }
295
296        @Override
297        public int hashCode() {
298            return Objects.hashCode(algorithm, password, salt, iterations);
299        }
300    }
301
302    private static class KeyPair {
303        final byte[] clientKey;
304        final byte[] serverKey;
305
306        KeyPair(final byte[] clientKey, final byte[] serverKey) {
307            this.clientKey = clientKey;
308            this.serverKey = serverKey;
309        }
310    }
311}