ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
  3import android.annotation.TargetApi;
  4import android.os.Build;
  5import android.util.Base64;
  6import android.util.LruCache;
  7
  8import org.bouncycastle.crypto.Digest;
  9import org.bouncycastle.crypto.macs.HMac;
 10import org.bouncycastle.crypto.params.KeyParameter;
 11
 12import java.math.BigInteger;
 13import java.nio.charset.Charset;
 14import java.security.InvalidKeyException;
 15import java.security.SecureRandom;
 16
 17import eu.siacs.conversations.entities.Account;
 18import eu.siacs.conversations.utils.CryptoHelper;
 19import eu.siacs.conversations.xml.TagWriter;
 20
 21@TargetApi(Build.VERSION_CODES.HONEYCOMB_MR1)
 22abstract class ScramMechanism extends SaslMechanism {
 23    // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to indicate support and/or usage.
 24    private final static String GS2_HEADER = "n,,";
 25    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 26    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 27    private static final LruCache<String, KeyPair> CACHE;
 28    static HMac HMAC;
 29    static Digest DIGEST;
 30
 31    static {
 32        CACHE = new LruCache<String, KeyPair>(10) {
 33            protected KeyPair create(final String k) {
 34                // Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations,SASL-Mechanism".
 35                // Changing any of these values forces a cache miss. `CryptoHelper.bytesToHex()'
 36                // is applied to prevent commas in the strings breaking things.
 37                final String[] kParts = k.split(",", 5);
 38                try {
 39                    final byte[] saltedPassword, serverKey, clientKey;
 40                    saltedPassword = hi(CryptoHelper.hexToString(kParts[1]).getBytes(),
 41                            Base64.decode(CryptoHelper.hexToString(kParts[2]), Base64.DEFAULT), Integer.parseInt(kParts[3]));
 42                    serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 43                    clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 44
 45                    return new KeyPair(clientKey, serverKey);
 46                } catch (final InvalidKeyException | NumberFormatException e) {
 47                    return null;
 48                }
 49            }
 50        };
 51    }
 52
 53    private final String clientNonce;
 54    protected State state = State.INITIAL;
 55    private String clientFirstMessageBare;
 56    private byte[] serverSignature = null;
 57
 58    ScramMechanism(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
 59        super(tagWriter, account, rng);
 60
 61        // This nonce should be different for each authentication attempt.
 62        clientNonce = CryptoHelper.random(100, rng);
 63        clientFirstMessageBare = "";
 64    }
 65
 66    private static synchronized byte[] hmac(final byte[] key, final byte[] input)
 67            throws InvalidKeyException {
 68        HMAC.init(new KeyParameter(key));
 69        HMAC.update(input, 0, input.length);
 70        final byte[] out = new byte[HMAC.getMacSize()];
 71        HMAC.doFinal(out, 0);
 72        return out;
 73    }
 74
 75    public static synchronized byte[] digest(byte[] bytes) {
 76        DIGEST.reset();
 77        DIGEST.update(bytes, 0, bytes.length);
 78        final byte[] out = new byte[DIGEST.getDigestSize()];
 79        DIGEST.doFinal(out, 0);
 80        return out;
 81    }
 82
 83    /*
 84     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
 85     * pseudorandom function (PRF) and with dkLen == output length of
 86     * HMAC() == output length of H().
 87     */
 88    private static synchronized byte[] hi(final byte[] key, final byte[] salt, final int iterations)
 89            throws InvalidKeyException {
 90        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
 91        byte[] out = u.clone();
 92        for (int i = 1; i < iterations; i++) {
 93            u = hmac(key, u);
 94            for (int j = 0; j < u.length; j++) {
 95                out[j] ^= u[j];
 96            }
 97        }
 98        return out;
 99    }
100
101    @Override
102    public String getClientFirstMessage() {
103        if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
104            clientFirstMessageBare = "n=" + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())) +
105                    ",r=" + this.clientNonce;
106            state = State.AUTH_TEXT_SENT;
107        }
108        return Base64.encodeToString(
109                (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
110                Base64.NO_WRAP);
111    }
112
113    @Override
114    public String getResponse(final String challenge) throws AuthenticationException {
115        switch (state) {
116            case AUTH_TEXT_SENT:
117                if (challenge == null) {
118                    throw new AuthenticationException("challenge can not be null");
119                }
120                byte[] serverFirstMessage;
121                try {
122                    serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
123                } catch (IllegalArgumentException e) {
124                    throw new AuthenticationException("Unable to decode server challenge", e);
125                }
126                final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
127                String nonce = "";
128                int iterationCount = -1;
129                String salt = "";
130                for (final String token : tokenizer) {
131                    if (token.charAt(1) == '=') {
132                        switch (token.charAt(0)) {
133                            case 'i':
134                                try {
135                                    iterationCount = Integer.parseInt(token.substring(2));
136                                } catch (final NumberFormatException e) {
137                                    throw new AuthenticationException(e);
138                                }
139                                break;
140                            case 's':
141                                salt = token.substring(2);
142                                break;
143                            case 'r':
144                                nonce = token.substring(2);
145                                break;
146                            case 'm':
147                                /*
148                                 * RFC 5802:
149                                 * m: This attribute is reserved for future extensibility.  In this
150                                 * version of SCRAM, its presence in a client or a server message
151                                 * MUST cause authentication failure when the attribute is parsed by
152                                 * the other end.
153                                 */
154                                throw new AuthenticationException("Server sent reserved token: `m'");
155                        }
156                    }
157                }
158
159                if (iterationCount < 0) {
160                    throw new AuthenticationException("Server did not send iteration count");
161                }
162                if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
163                    throw new AuthenticationException("Server nonce does not contain client nonce: " + nonce);
164                }
165                if (salt.isEmpty()) {
166                    throw new AuthenticationException("Server sent empty salt");
167                }
168
169                final String clientFinalMessageWithoutProof = "c=" + Base64.encodeToString(
170                        GS2_HEADER.getBytes(), Base64.NO_WRAP) + ",r=" + nonce;
171                final byte[] authMessage = (clientFirstMessageBare + ',' + new String(serverFirstMessage) + ','
172                        + clientFinalMessageWithoutProof).getBytes();
173
174                // Map keys are "bytesToHex(JID),bytesToHex(password),bytesToHex(salt),iterations,SASL-Mechanism".
175                final KeyPair keys = CACHE.get(
176                        CryptoHelper.bytesToHex(CryptoHelper.saslPrep(account.getJid().asBareJid().toEscapedString()).getBytes()) + ","
177                                + CryptoHelper.bytesToHex(CryptoHelper.saslPrep(account.getPassword()).getBytes()) + ","
178                                + CryptoHelper.bytesToHex(salt.getBytes()) + ","
179                                + iterationCount + ","
180                                + getMechanism()
181                );
182                if (keys == null) {
183                    throw new AuthenticationException("Invalid keys generated");
184                }
185                final byte[] clientSignature;
186                try {
187                    serverSignature = hmac(keys.serverKey, authMessage);
188                    final byte[] storedKey = digest(keys.clientKey);
189
190                    clientSignature = hmac(storedKey, authMessage);
191
192                } catch (final InvalidKeyException e) {
193                    throw new AuthenticationException(e);
194                }
195
196                final byte[] clientProof = new byte[keys.clientKey.length];
197
198                if (clientSignature.length < keys.clientKey.length) {
199                    throw new AuthenticationException("client signature was shorter than clientKey");
200                }
201
202                for (int i = 0; i < clientProof.length; i++) {
203                    clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
204                }
205
206
207                final String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" +
208                        Base64.encodeToString(clientProof, Base64.NO_WRAP);
209                state = State.RESPONSE_SENT;
210                return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
211            case RESPONSE_SENT:
212                try {
213                    final String clientCalculatedServerFinalMessage = "v=" +
214                            Base64.encodeToString(serverSignature, Base64.NO_WRAP);
215                    if (!clientCalculatedServerFinalMessage.equals(new String(Base64.decode(challenge, Base64.DEFAULT)))) {
216                        throw new Exception();
217                    }
218                    state = State.VALID_SERVER_RESPONSE;
219                    return "";
220                } catch (Exception e) {
221                    throw new AuthenticationException("Server final message does not match calculated final message");
222                }
223            default:
224                throw new InvalidStateException(state);
225        }
226    }
227
228    private static class KeyPair {
229        final byte[] clientKey;
230        final byte[] serverKey;
231
232        KeyPair(final byte[] clientKey, final byte[] serverKey) {
233            this.clientKey = clientKey;
234            this.serverKey = serverKey;
235        }
236    }
237}