ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.annotation.TargetApi;
  4import android.os.Build;
  5import android.util.Base64;
  6
  7import com.google.common.base.Objects;
  8import com.google.common.cache.Cache;
  9import com.google.common.cache.CacheBuilder;
 10
 11import org.bouncycastle.crypto.Digest;
 12import org.bouncycastle.crypto.macs.HMac;
 13import org.bouncycastle.crypto.params.KeyParameter;
 14
 15import java.nio.charset.Charset;
 16import java.security.InvalidKeyException;
 17import java.security.SecureRandom;
 18import java.util.concurrent.ExecutionException;
 19
 20import eu.siacs.conversations.entities.Account;
 21import eu.siacs.conversations.utils.CryptoHelper;
 22import eu.siacs.conversations.xml.TagWriter;
 23
 24@TargetApi(Build.VERSION_CODES.HONEYCOMB_MR1)
 25abstract class ScramMechanism extends SaslMechanism {
 26    // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to indicate support and/or usage.
 27    private final static String GS2_HEADER = "n,,";
 28    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 29    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 30
 31    protected abstract HMac getHMAC();
 32
 33    protected abstract Digest getDigest();
 34
 35    private static final Cache<CacheKey, KeyPair> CACHE = CacheBuilder.newBuilder().maximumSize(10).build();
 36
 37    private static class CacheKey {
 38        final String algorithm;
 39        final String password;
 40        final String salt;
 41        final int iterations;
 42
 43        private CacheKey(String algorithm, String password, String salt, int iterations) {
 44            this.algorithm = algorithm;
 45            this.password = password;
 46            this.salt = salt;
 47            this.iterations = iterations;
 48        }
 49
 50        @Override
 51        public boolean equals(Object o) {
 52            if (this == o) return true;
 53            if (o == null || getClass() != o.getClass()) return false;
 54            CacheKey cacheKey = (CacheKey) o;
 55            return iterations == cacheKey.iterations &&
 56                    Objects.equal(algorithm, cacheKey.algorithm) &&
 57                    Objects.equal(password, cacheKey.password) &&
 58                    Objects.equal(salt, cacheKey.salt);
 59        }
 60
 61        @Override
 62        public int hashCode() {
 63            return Objects.hashCode(algorithm, password, salt, iterations);
 64        }
 65    }
 66
 67    private KeyPair getKeyPair(final String password, final String salt, final int iterations) throws ExecutionException {
 68        return CACHE.get(new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations), () -> {
 69            final byte[] saltedPassword, serverKey, clientKey;
 70            saltedPassword = hi(password.getBytes(), Base64.decode(salt, Base64.DEFAULT), iterations);
 71            serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 72            clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 73            return new KeyPair(clientKey, serverKey);
 74        });
 75    }
 76
 77    private final String clientNonce;
 78    protected State state = State.INITIAL;
 79    private String clientFirstMessageBare;
 80    private byte[] serverSignature = null;
 81
 82    ScramMechanism(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
 83        super(tagWriter, account, rng);
 84
 85        // This nonce should be different for each authentication attempt.
 86        clientNonce = CryptoHelper.random(100, rng);
 87        clientFirstMessageBare = "";
 88    }
 89
 90    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
 91        final HMac hMac = getHMAC();
 92        hMac.init(new KeyParameter(key));
 93        hMac.update(input, 0, input.length);
 94        final byte[] out = new byte[hMac.getMacSize()];
 95        hMac.doFinal(out, 0);
 96        return out;
 97    }
 98
 99    public byte[] digest(byte[] bytes) {
100        final Digest digest = getDigest();
101        digest.reset();
102        digest.update(bytes, 0, bytes.length);
103        final byte[] out = new byte[digest.getDigestSize()];
104        digest.doFinal(out, 0);
105        return out;
106    }
107
108    /*
109     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
110     * pseudorandom function (PRF) and with dkLen == output length of
111     * HMAC() == output length of H().
112     */
113    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
114            throws InvalidKeyException {
115        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
116        byte[] out = u.clone();
117        for (int i = 1; i < iterations; i++) {
118            u = hmac(key, u);
119            for (int j = 0; j < u.length; j++) {
120                out[j] ^= u[j];
121            }
122        }
123        return out;
124    }
125
126    @Override
127    public String getClientFirstMessage() {
128        if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
129            clientFirstMessageBare = "n=" + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())) +
130                    ",r=" + this.clientNonce;
131            state = State.AUTH_TEXT_SENT;
132        }
133        return Base64.encodeToString(
134                (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
135                Base64.NO_WRAP);
136    }
137
138    @Override
139    public String getResponse(final String challenge) throws AuthenticationException {
140        switch (state) {
141            case AUTH_TEXT_SENT:
142                if (challenge == null) {
143                    throw new AuthenticationException("challenge can not be null");
144                }
145                byte[] serverFirstMessage;
146                try {
147                    serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
148                } catch (IllegalArgumentException e) {
149                    throw new AuthenticationException("Unable to decode server challenge", e);
150                }
151                final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
152                String nonce = "";
153                int iterationCount = -1;
154                String salt = "";
155                for (final String token : tokenizer) {
156                    if (token.charAt(1) == '=') {
157                        switch (token.charAt(0)) {
158                            case 'i':
159                                try {
160                                    iterationCount = Integer.parseInt(token.substring(2));
161                                } catch (final NumberFormatException e) {
162                                    throw new AuthenticationException(e);
163                                }
164                                break;
165                            case 's':
166                                salt = token.substring(2);
167                                break;
168                            case 'r':
169                                nonce = token.substring(2);
170                                break;
171                            case 'm':
172                                /*
173                                 * RFC 5802:
174                                 * m: This attribute is reserved for future extensibility.  In this
175                                 * version of SCRAM, its presence in a client or a server message
176                                 * MUST cause authentication failure when the attribute is parsed by
177                                 * the other end.
178                                 */
179                                throw new AuthenticationException("Server sent reserved token: `m'");
180                        }
181                    }
182                }
183
184                if (iterationCount < 0) {
185                    throw new AuthenticationException("Server did not send iteration count");
186                }
187                if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
188                    throw new AuthenticationException("Server nonce does not contain client nonce: " + nonce);
189                }
190                if (salt.isEmpty()) {
191                    throw new AuthenticationException("Server sent empty salt");
192                }
193
194                final String clientFinalMessageWithoutProof = "c=" + Base64.encodeToString(
195                        GS2_HEADER.getBytes(), Base64.NO_WRAP) + ",r=" + nonce;
196                final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
197                        + clientFinalMessageWithoutProof).getBytes();
198
199                final KeyPair keys;
200                try {
201                    keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
202                } catch (ExecutionException e) {
203                    throw new AuthenticationException("Invalid keys generated");
204                }
205                final byte[] clientSignature;
206                try {
207                    serverSignature = hmac(keys.serverKey, authMessage);
208                    final byte[] storedKey = digest(keys.clientKey);
209
210                    clientSignature = hmac(storedKey, authMessage);
211
212                } catch (final InvalidKeyException e) {
213                    throw new AuthenticationException(e);
214                }
215
216                final byte[] clientProof = new byte[keys.clientKey.length];
217
218                if (clientSignature.length < keys.clientKey.length) {
219                    throw new AuthenticationException("client signature was shorter than clientKey");
220                }
221
222                for (int i = 0; i < clientProof.length; i++) {
223                    clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
224                }
225
226
227                final String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" +
228                        Base64.encodeToString(clientProof, Base64.NO_WRAP);
229                state = State.RESPONSE_SENT;
230                return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
231            case RESPONSE_SENT:
232                try {
233                    final String clientCalculatedServerFinalMessage = "v=" +
234                            Base64.encodeToString(serverSignature, Base64.NO_WRAP);
235                    if (!clientCalculatedServerFinalMessage.equals(new String(Base64.decode(challenge, Base64.DEFAULT)))) {
236                        throw new Exception();
237                    }
238                    state = State.VALID_SERVER_RESPONSE;
239                    return "";
240                } catch (Exception e) {
241                    throw new AuthenticationException("Server final message does not match calculated final message");
242                }
243            default:
244                throw new InvalidStateException(state);
245        }
246    }
247
248    private static class KeyPair {
249        final byte[] clientKey;
250        final byte[] serverKey;
251
252        KeyPair(final byte[] clientKey, final byte[] serverKey) {
253            this.clientKey = clientKey;
254            this.serverKey = serverKey;
255        }
256    }
257}