ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import android.util.Base64;

import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.util.concurrent.ExecutionException;

import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
 19
 20abstract class ScramMechanism extends SaslMechanism {
 21    // TODO: When channel binding (SCRAM-SHA1-PLUS) is supported in future, generalize this to
 22    // indicate support and/or usage.
 23    private static final String GS2_HEADER = "n,,";
 24    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 25    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 26    private static final Cache<CacheKey, KeyPair> CACHE =
 27            CacheBuilder.newBuilder().maximumSize(10).build();
 28    protected final ChannelBinding channelBinding;
 29    private final String clientNonce;
 30    protected State state = State.INITIAL;
 31    private String clientFirstMessageBare;
 32    private byte[] serverSignature = null;
 33
 34    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 35        super(account);
 36        this.channelBinding = channelBinding;
 37        // This nonce should be different for each authentication attempt.
 38        this.clientNonce = CryptoHelper.random(100);
 39        clientFirstMessageBare = "";
 40    }
 41
 42    protected abstract HMac getHMAC();
 43
 44    protected abstract Digest getDigest();
 45
 46    private KeyPair getKeyPair(final String password, final String salt, final int iterations)
 47            throws ExecutionException {
 48        return CACHE.get(
 49                new CacheKey(getHMAC().getAlgorithmName(), password, salt, iterations),
 50                () -> {
 51                    final byte[] saltedPassword, serverKey, clientKey;
 52                    saltedPassword =
 53                            hi(
 54                                    password.getBytes(),
 55                                    Base64.decode(salt, Base64.DEFAULT),
 56                                    iterations);
 57                    serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 58                    clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 59                    return new KeyPair(clientKey, serverKey);
 60                });
 61    }
 62
 63    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
 64        final HMac hMac = getHMAC();
 65        hMac.init(new KeyParameter(key));
 66        hMac.update(input, 0, input.length);
 67        final byte[] out = new byte[hMac.getMacSize()];
 68        hMac.doFinal(out, 0);
 69        return out;
 70    }
 71
 72    public byte[] digest(byte[] bytes) {
 73        final Digest digest = getDigest();
 74        digest.reset();
 75        digest.update(bytes, 0, bytes.length);
 76        final byte[] out = new byte[digest.getDigestSize()];
 77        digest.doFinal(out, 0);
 78        return out;
 79    }
 80
 81    /*
 82     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
 83     * pseudorandom function (PRF) and with dkLen == output length of
 84     * HMAC() == output length of H().
 85     */
 86    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
 87            throws InvalidKeyException {
 88        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
 89        byte[] out = u.clone();
 90        for (int i = 1; i < iterations; i++) {
 91            u = hmac(key, u);
 92            for (int j = 0; j < u.length; j++) {
 93                out[j] ^= u[j];
 94            }
 95        }
 96        return out;
 97    }
 98
 99    @Override
100    public String getClientFirstMessage() {
101        if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
102            clientFirstMessageBare =
103                    "n="
104                            + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
105                            + ",r="
106                            + this.clientNonce;
107            state = State.AUTH_TEXT_SENT;
108        }
109        return Base64.encodeToString(
110                (GS2_HEADER + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
111                Base64.NO_WRAP);
112    }
113
114    @Override
115    public String getResponse(final String challenge) throws AuthenticationException {
116        switch (state) {
117            case AUTH_TEXT_SENT:
118                if (challenge == null) {
119                    throw new AuthenticationException("challenge can not be null");
120                }
121                byte[] serverFirstMessage;
122                try {
123                    serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
124                } catch (IllegalArgumentException e) {
125                    throw new AuthenticationException("Unable to decode server challenge", e);
126                }
127                final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
128                String nonce = "";
129                int iterationCount = -1;
130                String salt = "";
131                for (final String token : tokenizer) {
132                    if (token.charAt(1) == '=') {
133                        switch (token.charAt(0)) {
134                            case 'i':
135                                try {
136                                    iterationCount = Integer.parseInt(token.substring(2));
137                                } catch (final NumberFormatException e) {
138                                    throw new AuthenticationException(e);
139                                }
140                                break;
141                            case 's':
142                                salt = token.substring(2);
143                                break;
144                            case 'r':
145                                nonce = token.substring(2);
146                                break;
147                            case 'm':
148                                /*
149                                 * RFC 5802:
150                                 * m: This attribute is reserved for future extensibility.  In this
151                                 * version of SCRAM, its presence in a client or a server message
152                                 * MUST cause authentication failure when the attribute is parsed by
153                                 * the other end.
154                                 */
155                                throw new AuthenticationException(
156                                        "Server sent reserved token: `m'");
157                        }
158                    }
159                }
160
161                if (iterationCount < 0) {
162                    throw new AuthenticationException("Server did not send iteration count");
163                }
164                if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
165                    throw new AuthenticationException(
166                            "Server nonce does not contain client nonce: " + nonce);
167                }
168                if (salt.isEmpty()) {
169                    throw new AuthenticationException("Server sent empty salt");
170                }
171
172                final String clientFinalMessageWithoutProof =
173                        "c="
174                                + Base64.encodeToString(GS2_HEADER.getBytes(), Base64.NO_WRAP)
175                                + ",r="
176                                + nonce;
177                final byte[] authMessage =
178                        (clientFirstMessageBare
179                                        + ','
180                                        + new String(serverFirstMessage)
181                                        + ','
182                                        + clientFinalMessageWithoutProof)
183                                .getBytes();
184
185                final KeyPair keys;
186                try {
187                    keys =
188                            getKeyPair(
189                                    CryptoHelper.saslPrep(account.getPassword()),
190                                    salt,
191                                    iterationCount);
192                } catch (ExecutionException e) {
193                    throw new AuthenticationException("Invalid keys generated");
194                }
195                final byte[] clientSignature;
196                try {
197                    serverSignature = hmac(keys.serverKey, authMessage);
198                    final byte[] storedKey = digest(keys.clientKey);
199
200                    clientSignature = hmac(storedKey, authMessage);
201
202                } catch (final InvalidKeyException e) {
203                    throw new AuthenticationException(e);
204                }
205
206                final byte[] clientProof = new byte[keys.clientKey.length];
207
208                if (clientSignature.length < keys.clientKey.length) {
209                    throw new AuthenticationException(
210                            "client signature was shorter than clientKey");
211                }
212
213                for (int i = 0; i < clientProof.length; i++) {
214                    clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
215                }
216
217                final String clientFinalMessage =
218                        clientFinalMessageWithoutProof
219                                + ",p="
220                                + Base64.encodeToString(clientProof, Base64.NO_WRAP);
221                state = State.RESPONSE_SENT;
222                return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
223            case RESPONSE_SENT:
224                try {
225                    final String clientCalculatedServerFinalMessage =
226                            "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
227                    if (!clientCalculatedServerFinalMessage.equals(
228                            new String(Base64.decode(challenge, Base64.DEFAULT)))) {
229                        throw new Exception();
230                    }
231                    state = State.VALID_SERVER_RESPONSE;
232                    return "";
233                } catch (Exception e) {
234                    throw new AuthenticationException(
235                            "Server final message does not match calculated final message");
236                }
237            default:
238                throw new InvalidStateException(state);
239        }
240    }
241
242    private static class CacheKey {
243        final String algorithm;
244        final String password;
245        final String salt;
246        final int iterations;
247
248        private CacheKey(String algorithm, String password, String salt, int iterations) {
249            this.algorithm = algorithm;
250            this.password = password;
251            this.salt = salt;
252            this.iterations = iterations;
253        }
254
255        @Override
256        public boolean equals(Object o) {
257            if (this == o) return true;
258            if (o == null || getClass() != o.getClass()) return false;
259            CacheKey cacheKey = (CacheKey) o;
260            return iterations == cacheKey.iterations
261                    && Objects.equal(algorithm, cacheKey.algorithm)
262                    && Objects.equal(password, cacheKey.password)
263                    && Objects.equal(salt, cacheKey.salt);
264        }
265
266        @Override
267        public int hashCode() {
268            return Objects.hashCode(algorithm, password, salt, iterations);
269        }
270    }
271
272    private static class KeyPair {
273        final byte[] clientKey;
274        final byte[] serverKey;
275
276        KeyPair(final byte[] clientKey, final byte[] serverKey) {
277            this.clientKey = clientKey;
278            this.serverKey = serverKey;
279        }
280    }
281}