ScramMechanism.java

  1package eu.siacs.conversations.crypto.sasl;
  2
import android.util.Base64;
import android.util.Log;

import com.google.common.base.CaseFormat;
import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.hash.HashFunction;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.util.concurrent.ExecutionException;

import javax.net.ssl.SSLSocket;

import eu.siacs.conversations.Config;
import eu.siacs.conversations.entities.Account;
import eu.siacs.conversations.utils.CryptoHelper;
 21
 22abstract class ScramMechanism extends SaslMechanism {
 23
 24    private static final byte[] CLIENT_KEY_BYTES = "Client Key".getBytes();
 25    private static final byte[] SERVER_KEY_BYTES = "Server Key".getBytes();
 26    private static final Cache<CacheKey, KeyPair> CACHE =
 27            CacheBuilder.newBuilder().maximumSize(10).build();
 28    protected final ChannelBinding channelBinding;
 29    private final String gs2Header;
 30    private final String clientNonce;
 31    protected State state = State.INITIAL;
 32    private String clientFirstMessageBare;
 33    private byte[] serverSignature = null;
 34
 35    ScramMechanism(final Account account, final ChannelBinding channelBinding) {
 36        super(account);
 37        this.channelBinding = channelBinding;
 38        if (channelBinding == ChannelBinding.NONE) {
 39            // TODO this needs to be changed to "y,," for the scram internal down grade protection
 40            // but we might risk compatibility issues if the server supports a binding that we don’t
 41            // support
 42            this.gs2Header = "n,,";
 43        } else {
 44            this.gs2Header =
 45                    String.format(
 46                            "p=%s,,",
 47                            CaseFormat.UPPER_UNDERSCORE
 48                                    .converterTo(CaseFormat.LOWER_HYPHEN)
 49                                    .convert(channelBinding.toString()));
 50        }
 51        // This nonce should be different for each authentication attempt.
 52        this.clientNonce = CryptoHelper.random(100);
 53        clientFirstMessageBare = "";
 54    }
 55
 56    protected abstract HashFunction getHMac(final byte[] key);
 57
 58    protected abstract HashFunction getDigest();
 59
 60    private KeyPair getKeyPair(final String password, final String salt, final int iterations)
 61            throws ExecutionException {
 62        return CACHE.get(
 63                new CacheKey(getMechanism(), password, salt, iterations),
 64                () -> {
 65                    final byte[] saltedPassword, serverKey, clientKey;
 66                    saltedPassword =
 67                            hi(
 68                                    password.getBytes(),
 69                                    Base64.decode(salt, Base64.DEFAULT),
 70                                    iterations);
 71                    serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
 72                    clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
 73                    return new KeyPair(clientKey, serverKey);
 74                });
 75    }
 76
 77    private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
 78        return getHMac(key).hashBytes(input).asBytes();
 79    }
 80
 81    private byte[] digest(final byte[] bytes) {
 82        return getDigest().hashBytes(bytes).asBytes();
 83    }
 84
 85    /*
 86     * Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the
 87     * pseudorandom function (PRF) and with dkLen == output length of
 88     * HMAC() == output length of H().
 89     */
 90    private byte[] hi(final byte[] key, final byte[] salt, final int iterations)
 91            throws InvalidKeyException {
 92        byte[] u = hmac(key, CryptoHelper.concatenateByteArrays(salt, CryptoHelper.ONE));
 93        byte[] out = u.clone();
 94        for (int i = 1; i < iterations; i++) {
 95            u = hmac(key, u);
 96            for (int j = 0; j < u.length; j++) {
 97                out[j] ^= u[j];
 98            }
 99        }
100        return out;
101    }
102
103    @Override
104    public String getClientFirstMessage(final SSLSocket sslSocket) {
105        if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
106            clientFirstMessageBare =
107                    "n="
108                            + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
109                            + ",r="
110                            + this.clientNonce;
111            state = State.AUTH_TEXT_SENT;
112        }
113        return Base64.encodeToString(
114                (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
115                Base64.NO_WRAP);
116    }
117
118    @Override
119    public String getResponse(final String challenge, final SSLSocket socket)
120            throws AuthenticationException {
121        switch (state) {
122            case AUTH_TEXT_SENT:
123                if (challenge == null) {
124                    throw new AuthenticationException("challenge can not be null");
125                }
126                byte[] serverFirstMessage;
127                try {
128                    serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
129                } catch (IllegalArgumentException e) {
130                    throw new AuthenticationException("Unable to decode server challenge", e);
131                }
132                final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
133                String nonce = "";
134                int iterationCount = -1;
135                String salt = "";
136                for (final String token : tokenizer) {
137                    if (token.charAt(1) == '=') {
138                        switch (token.charAt(0)) {
139                            case 'i':
140                                try {
141                                    iterationCount = Integer.parseInt(token.substring(2));
142                                } catch (final NumberFormatException e) {
143                                    throw new AuthenticationException(e);
144                                }
145                                break;
146                            case 's':
147                                salt = token.substring(2);
148                                break;
149                            case 'r':
150                                nonce = token.substring(2);
151                                break;
152                            case 'm':
153                                /*
154                                 * RFC 5802:
155                                 * m: This attribute is reserved for future extensibility.  In this
156                                 * version of SCRAM, its presence in a client or a server message
157                                 * MUST cause authentication failure when the attribute is parsed by
158                                 * the other end.
159                                 */
160                                throw new AuthenticationException(
161                                        "Server sent reserved token: `m'");
162                        }
163                    }
164                }
165
166                if (iterationCount < 0) {
167                    throw new AuthenticationException("Server did not send iteration count");
168                }
169                if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
170                    throw new AuthenticationException(
171                            "Server nonce does not contain client nonce: " + nonce);
172                }
173                if (salt.isEmpty()) {
174                    throw new AuthenticationException("Server sent empty salt");
175                }
176
177                final byte[] channelBindingData = getChannelBindingData(socket);
178
179                final int gs2Len = this.gs2Header.getBytes().length;
180                final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
181                System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
182                System.arraycopy(
183                        channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
184
185                final String clientFinalMessageWithoutProof =
186                        "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
187
188                final byte[] authMessage =
189                        (clientFirstMessageBare
190                                        + ','
191                                        + new String(serverFirstMessage)
192                                        + ','
193                                        + clientFinalMessageWithoutProof)
194                                .getBytes();
195
196                final KeyPair keys;
197                try {
198                    keys =
199                            getKeyPair(
200                                    CryptoHelper.saslPrep(account.getPassword()),
201                                    salt,
202                                    iterationCount);
203                } catch (ExecutionException e) {
204                    throw new AuthenticationException("Invalid keys generated");
205                }
206                final byte[] clientSignature;
207                try {
208                    serverSignature = hmac(keys.serverKey, authMessage);
209                    final byte[] storedKey = digest(keys.clientKey);
210
211                    clientSignature = hmac(storedKey, authMessage);
212
213                } catch (final InvalidKeyException e) {
214                    throw new AuthenticationException(e);
215                }
216
217                final byte[] clientProof = new byte[keys.clientKey.length];
218
219                if (clientSignature.length < keys.clientKey.length) {
220                    throw new AuthenticationException(
221                            "client signature was shorter than clientKey");
222                }
223
224                for (int i = 0; i < clientProof.length; i++) {
225                    clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
226                }
227
228                final String clientFinalMessage =
229                        clientFinalMessageWithoutProof
230                                + ",p="
231                                + Base64.encodeToString(clientProof, Base64.NO_WRAP);
232                state = State.RESPONSE_SENT;
233                return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
234            case RESPONSE_SENT:
235                try {
236                    final String clientCalculatedServerFinalMessage =
237                            "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
238                    if (!clientCalculatedServerFinalMessage.equals(
239                            new String(Base64.decode(challenge, Base64.DEFAULT)))) {
240                        throw new Exception();
241                    }
242                    state = State.VALID_SERVER_RESPONSE;
243                    return "";
244                } catch (Exception e) {
245                    throw new AuthenticationException(
246                            "Server final message does not match calculated final message");
247                }
248            default:
249                throw new InvalidStateException(state);
250        }
251    }
252
253    protected byte[] getChannelBindingData(final SSLSocket sslSocket)
254            throws AuthenticationException {
255        if (this.channelBinding == ChannelBinding.NONE) {
256            return new byte[0];
257        }
258        throw new AssertionError("getChannelBindingData needs to be overwritten");
259    }
260
261    private static class CacheKey {
262        final String algorithm;
263        final String password;
264        final String salt;
265        final int iterations;
266
267        private CacheKey(String algorithm, String password, String salt, int iterations) {
268            this.algorithm = algorithm;
269            this.password = password;
270            this.salt = salt;
271            this.iterations = iterations;
272        }
273
274        @Override
275        public boolean equals(Object o) {
276            if (this == o) return true;
277            if (o == null || getClass() != o.getClass()) return false;
278            CacheKey cacheKey = (CacheKey) o;
279            return iterations == cacheKey.iterations
280                    && Objects.equal(algorithm, cacheKey.algorithm)
281                    && Objects.equal(password, cacheKey.password)
282                    && Objects.equal(salt, cacheKey.salt);
283        }
284
285        @Override
286        public int hashCode() {
287            return Objects.hashCode(algorithm, password, salt, iterations);
288        }
289    }
290
291    private static class KeyPair {
292        final byte[] clientKey;
293        final byte[] serverKey;
294
295        KeyPair(final byte[] clientKey, final byte[] serverKey) {
296            this.clientKey = clientKey;
297            this.serverKey = serverKey;
298        }
299    }
300}