@@ -1,23 +1,25 @@
package eu.siacs.conversations.crypto.sasl;
-import android.util.Base64;
-
import com.google.common.base.CaseFormat;
+import com.google.common.base.Joiner;
import com.google.common.base.Objects;
+import com.google.common.base.Splitter;
+import com.google.common.base.Strings;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
+import com.google.common.collect.ImmutableMap;
import com.google.common.hash.HashFunction;
-
-import java.nio.charset.Charset;
+import com.google.common.io.BaseEncoding;
+import com.google.common.primitives.Ints;
+import eu.siacs.conversations.entities.Account;
+import eu.siacs.conversations.utils.CryptoHelper;
import java.security.InvalidKeyException;
+import java.util.Arrays;
+import java.util.Map;
import java.util.concurrent.ExecutionException;
-
import javax.crypto.SecretKey;
import javax.net.ssl.SSLSocket;
-import eu.siacs.conversations.entities.Account;
-import eu.siacs.conversations.utils.CryptoHelper;
-
abstract class ScramMechanism extends SaslMechanism {
public static final SecretKey EMPTY_KEY =
@@ -46,7 +48,7 @@ abstract class ScramMechanism extends SaslMechanism {
private final String gs2Header;
private final String clientNonce;
protected State state = State.INITIAL;
- private String clientFirstMessageBare;
+ private final String clientFirstMessageBare;
private byte[] serverSignature = null;
ScramMechanism(final Account account, final ChannelBinding channelBinding) {
@@ -67,28 +69,35 @@ abstract class ScramMechanism extends SaslMechanism {
}
// This nonce should be different for each authentication attempt.
this.clientNonce = CryptoHelper.random(100);
- clientFirstMessageBare = "";
+ this.clientFirstMessageBare =
+ String.format(
+ "n=%s,r=%s",
+ CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername())),
+ this.clientNonce);
}
protected abstract HashFunction getHMac(final byte[] key);
protected abstract HashFunction getDigest();
- private KeyPair getKeyPair(final String password, final String salt, final int iterations)
+ private KeyPair getKeyPair(final String password, final byte[] salt, final int iterations)
throws ExecutionException {
- return CACHE.get(
- new CacheKey(getMechanism(), password, salt, iterations),
- () -> {
- final byte[] saltedPassword, serverKey, clientKey;
- saltedPassword =
- hi(
- password.getBytes(),
- Base64.decode(salt, Base64.DEFAULT),
- iterations);
- serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
- clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
- return new KeyPair(clientKey, serverKey);
- });
+ final var key = new CacheKey(getMechanism(), password, salt, iterations);
+ return CACHE.get(key, () -> calculateKeyPair(password, salt, iterations));
+ }
+
+ private KeyPair calculateKeyPair(final String password, final byte[] salt, final int iterations)
+ throws InvalidKeyException {
+ final byte[] saltedPassword, serverKey, clientKey;
+ saltedPassword = hi(password.getBytes(), salt, iterations);
+ serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
+ clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
+ return new KeyPair(clientKey, serverKey);
+ }
+
+ @Override
+ public String getMechanism() {
+ return "";
}
private byte[] hmac(final byte[] key, final byte[] input) throws InvalidKeyException {
@@ -119,152 +128,155 @@ abstract class ScramMechanism extends SaslMechanism {
@Override
public String getClientFirstMessage(final SSLSocket sslSocket) {
- if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
- clientFirstMessageBare =
- "n="
- + CryptoHelper.saslEscape(CryptoHelper.saslPrep(account.getUsername()))
- + ",r="
- + this.clientNonce;
- state = State.AUTH_TEXT_SENT;
+ if (this.state != State.INITIAL) {
+ throw new IllegalArgumentException("Calling getClientFirstMessage from invalid state");
}
- return Base64.encodeToString(
- (gs2Header + clientFirstMessageBare).getBytes(Charset.defaultCharset()),
- Base64.NO_WRAP);
+ this.state = State.AUTH_TEXT_SENT;
+ final byte[] message = (gs2Header + clientFirstMessageBare).getBytes();
+ return BaseEncoding.base64().encode(message);
}
@Override
public String getResponse(final String challenge, final SSLSocket socket)
throws AuthenticationException {
- switch (state) {
- case AUTH_TEXT_SENT:
- if (challenge == null) {
- throw new AuthenticationException("challenge can not be null");
- }
- byte[] serverFirstMessage;
- try {
- serverFirstMessage = Base64.decode(challenge, Base64.DEFAULT);
- } catch (IllegalArgumentException e) {
- throw new AuthenticationException("Unable to decode server challenge", e);
- }
- final Tokenizer tokenizer = new Tokenizer(serverFirstMessage);
- String nonce = "";
- int iterationCount = -1;
- String salt = "";
- for (final String token : tokenizer) {
- if (token.charAt(1) == '=') {
- switch (token.charAt(0)) {
- case 'i':
- try {
- iterationCount = Integer.parseInt(token.substring(2));
- } catch (final NumberFormatException e) {
- throw new AuthenticationException(e);
- }
- break;
- case 's':
- salt = token.substring(2);
- break;
- case 'r':
- nonce = token.substring(2);
- break;
- case 'm':
- /*
- * RFC 5802:
- * m: This attribute is reserved for future extensibility. In this
- * version of SCRAM, its presence in a client or a server message
- * MUST cause authentication failure when the attribute is parsed by
- * the other end.
- */
- throw new AuthenticationException(
- "Server sent reserved token: `m'");
- }
- }
- }
+ return switch (state) {
+ case AUTH_TEXT_SENT -> processServerFirstMessage(challenge, socket);
+ case RESPONSE_SENT -> processServerFinalMessage(challenge);
+ default -> throw new InvalidStateException(state);
+ };
+ }
- if (iterationCount < 0) {
- throw new AuthenticationException("Server did not send iteration count");
- }
- if (nonce.isEmpty() || !nonce.startsWith(clientNonce)) {
- throw new AuthenticationException(
- "Server nonce does not contain client nonce: " + nonce);
- }
- if (salt.isEmpty()) {
- throw new AuthenticationException("Server sent empty salt");
- }
+ private String processServerFirstMessage(final String challenge, final SSLSocket socket)
+ throws AuthenticationException {
+ if (Strings.isNullOrEmpty(challenge)) {
+ throw new AuthenticationException("challenge can not be null");
+ }
+ byte[] serverFirstMessage;
+ try {
+ serverFirstMessage = BaseEncoding.base64().decode(challenge);
+ } catch (final IllegalArgumentException e) {
+ throw new AuthenticationException("Unable to decode server challenge", e);
+ }
+ final Map<String, String> attributes;
+ try {
+ attributes = splitToAttributes(new String(serverFirstMessage));
+ } catch (final IllegalArgumentException e) {
+ throw new AuthenticationException("Duplicate attributes");
+ }
+ if (attributes.containsKey("m")) {
+ /*
+ * RFC 5802:
+ * m: This attribute is reserved for future extensibility. In this
+ * version of SCRAM, its presence in a client or a server message
+ * MUST cause authentication failure when the attribute is parsed by
+ * the other end.
+ */
+ throw new AuthenticationException("Server sent reserved token: 'm'");
+ }
+ final String i = attributes.get("i");
+ final String s = attributes.get("s");
+ final String nonce = attributes.get("r");
+ final String d = attributes.get("d");
+ if (Strings.isNullOrEmpty(s) || Strings.isNullOrEmpty(nonce) || Strings.isNullOrEmpty(i)) {
+ throw new AuthenticationException("Missing attributes from server first message");
+ }
+ final Integer iterationCount = Ints.tryParse(i);
- final byte[] channelBindingData = getChannelBindingData(socket);
-
- final int gs2Len = this.gs2Header.getBytes().length;
- final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
- System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
- System.arraycopy(
- channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
-
- final String clientFinalMessageWithoutProof =
- "c=" + Base64.encodeToString(cMessage, Base64.NO_WRAP) + ",r=" + nonce;
-
- final byte[] authMessage =
- (clientFirstMessageBare
- + ','
- + new String(serverFirstMessage)
- + ','
- + clientFinalMessageWithoutProof)
- .getBytes();
-
- final KeyPair keys;
- try {
- keys =
- getKeyPair(
- CryptoHelper.saslPrep(account.getPassword()),
- salt,
- iterationCount);
- } catch (ExecutionException e) {
- throw new AuthenticationException("Invalid keys generated");
- }
- final byte[] clientSignature;
- try {
- serverSignature = hmac(keys.serverKey, authMessage);
- final byte[] storedKey = digest(keys.clientKey);
+ if (iterationCount == null || iterationCount < 0) {
+ throw new AuthenticationException("Server did not send iteration count");
+ }
+ if (!nonce.startsWith(clientNonce)) {
+ throw new AuthenticationException(
+ "Server nonce does not contain client nonce: " + nonce);
+ }
- clientSignature = hmac(storedKey, authMessage);
+ final byte[] salt;
- } catch (final InvalidKeyException e) {
- throw new AuthenticationException(e);
- }
+ try {
+ salt = BaseEncoding.base64().decode(s);
+ } catch (final IllegalArgumentException e) {
+ throw new AuthenticationException("Invalid salt in server first message");
+ }
- final byte[] clientProof = new byte[keys.clientKey.length];
+ final byte[] channelBindingData = getChannelBindingData(socket);
- if (clientSignature.length < keys.clientKey.length) {
- throw new AuthenticationException(
- "client signature was shorter than clientKey");
- }
+ final int gs2Len = this.gs2Header.getBytes().length;
+ final byte[] cMessage = new byte[gs2Len + channelBindingData.length];
+ System.arraycopy(this.gs2Header.getBytes(), 0, cMessage, 0, gs2Len);
+ System.arraycopy(channelBindingData, 0, cMessage, gs2Len, channelBindingData.length);
- for (int i = 0; i < clientProof.length; i++) {
- clientProof[i] = (byte) (keys.clientKey[i] ^ clientSignature[i]);
- }
+ final String clientFinalMessageWithoutProof =
+ String.format("c=%s,r=%s", BaseEncoding.base64().encode(cMessage), nonce);
- final String clientFinalMessage =
- clientFinalMessageWithoutProof
- + ",p="
- + Base64.encodeToString(clientProof, Base64.NO_WRAP);
- state = State.RESPONSE_SENT;
- return Base64.encodeToString(clientFinalMessage.getBytes(), Base64.NO_WRAP);
- case RESPONSE_SENT:
- try {
- final String clientCalculatedServerFinalMessage =
- "v=" + Base64.encodeToString(serverSignature, Base64.NO_WRAP);
- if (!clientCalculatedServerFinalMessage.equals(
- new String(Base64.decode(challenge, Base64.DEFAULT)))) {
- throw new Exception();
- }
- state = State.VALID_SERVER_RESPONSE;
- return "";
- } catch (Exception e) {
- throw new AuthenticationException(
- "Server final message does not match calculated final message");
- }
- default:
- throw new InvalidStateException(state);
+ final var authMessage =
+ Joiner.on(',')
+ .join(
+ clientFirstMessageBare,
+ new String(serverFirstMessage),
+ clientFinalMessageWithoutProof);
+
+ final KeyPair keys;
+ try {
+ keys = getKeyPair(CryptoHelper.saslPrep(account.getPassword()), salt, iterationCount);
+ } catch (final ExecutionException e) {
+ throw new AuthenticationException("Invalid keys generated");
+ }
+ final byte[] clientSignature;
+ try {
+ serverSignature = hmac(keys.serverKey, authMessage.getBytes());
+ final byte[] storedKey = digest(keys.clientKey);
+
+ clientSignature = hmac(storedKey, authMessage.getBytes());
+
+ } catch (final InvalidKeyException e) {
+ throw new AuthenticationException(e);
+ }
+
+ final byte[] clientProof = new byte[keys.clientKey.length];
+
+ if (clientSignature.length < keys.clientKey.length) {
+ throw new AuthenticationException("client signature was shorter than clientKey");
}
+
+ for (int j = 0; j < clientProof.length; j++) {
+ clientProof[j] = (byte) (keys.clientKey[j] ^ clientSignature[j]);
+ }
+
+ final var clientFinalMessage =
+ String.format(
+ "%s,p=%s",
+ clientFinalMessageWithoutProof, BaseEncoding.base64().encode(clientProof));
+ this.state = State.RESPONSE_SENT;
+ return BaseEncoding.base64().encode(clientFinalMessage.getBytes());
+ }
+
+ private Map<String, String> splitToAttributes(final String message) {
+ final ImmutableMap.Builder<String, String> builder = new ImmutableMap.Builder<>();
+ for (final String token : Splitter.on(',').split(message)) {
+ final var tuple = Splitter.on('=').limit(2).splitToList(token);
+ if (tuple.size() == 2) {
+ builder.put(tuple.get(0), tuple.get(1));
+ }
+ }
+ return builder.buildOrThrow();
+ }
+
+ private String processServerFinalMessage(final String challenge)
+ throws AuthenticationException {
+ final String serverFinalMessage;
+ try {
+ serverFinalMessage = new String(BaseEncoding.base64().decode(challenge));
+ } catch (final IllegalArgumentException e) {
+ throw new AuthenticationException("Invalid base64 in server final message", e);
+ }
+ final var clientCalculatedServerFinalMessage =
+ String.format("v=%s", BaseEncoding.base64().encode(serverSignature));
+ if (clientCalculatedServerFinalMessage.equals(serverFinalMessage)) {
+ this.state = State.VALID_SERVER_RESPONSE;
+ return "";
+ }
+ throw new AuthenticationException(
+ "Server final message does not match calculated final message");
}
protected byte[] getChannelBindingData(final SSLSocket sslSocket)
@@ -276,12 +288,16 @@ abstract class ScramMechanism extends SaslMechanism {
}
private static class CacheKey {
- final String algorithm;
- final String password;
- final String salt;
- final int iterations;
-
- private CacheKey(String algorithm, String password, String salt, int iterations) {
+ private final String algorithm;
+ private final String password;
+ private final byte[] salt;
+ private final int iterations;
+
+ private CacheKey(
+ final String algorithm,
+ final String password,
+ final byte[] salt,
+ final int iterations) {
this.algorithm = algorithm;
this.password = password;
this.salt = salt;
@@ -289,19 +305,20 @@ abstract class ScramMechanism extends SaslMechanism {
}
@Override
- public boolean equals(Object o) {
+ public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
CacheKey cacheKey = (CacheKey) o;
return iterations == cacheKey.iterations
&& Objects.equal(algorithm, cacheKey.algorithm)
&& Objects.equal(password, cacheKey.password)
- && Objects.equal(salt, cacheKey.salt);
+ && Arrays.equals(salt, cacheKey.salt);
}
@Override
public int hashCode() {
- return Objects.hashCode(algorithm, password, salt, iterations);
+ final int result = Objects.hashCode(algorithm, password, iterations);
+ return 31 * result + Arrays.hashCode(salt);
}
}
@@ -68,6 +68,7 @@ import im.conversations.android.xmpp.model.AuthenticationStreamFeature;
import im.conversations.android.xmpp.model.StreamElement;
import im.conversations.android.xmpp.model.bind2.Bind;
import im.conversations.android.xmpp.model.bind2.Bound;
+import im.conversations.android.xmpp.model.cb.SaslChannelBinding;
import im.conversations.android.xmpp.model.csi.Active;
import im.conversations.android.xmpp.model.csi.Inactive;
import im.conversations.android.xmpp.model.error.Condition;
@@ -917,7 +918,6 @@ public class XmppConnection implements Runnable {
final Tag tag = tagReader.readTag();
if (tag != null && tag.isStart("stream", Namespace.STREAMS)) {
processStream();
- return;
} else {
throw new StateChangingException(Account.State.STREAM_OPENING_ERROR);
}
@@ -1552,9 +1552,8 @@ public class XmppConnection implements Runnable {
authElement = this.streamFeatures.getExtension(Authentication.class);
}
final Collection<String> mechanisms = authElement.getMechanismNames();
- final Element cbElement =
- this.streamFeatures.findChild("sasl-channel-binding", Namespace.CHANNEL_BINDING);
- final Collection<ChannelBinding> channelBindings = ChannelBinding.of(cbElement);
+ final var cbExtension = this.streamFeatures.getExtension(SaslChannelBinding.class);
+ final Collection<ChannelBinding> channelBindings = ChannelBinding.of(cbExtension);
final SaslMechanism.Factory factory = new SaslMechanism.Factory(account);
final SaslMechanism saslMechanism =
factory.of(mechanisms, channelBindings, version, SSLSockets.version(this.socket));
@@ -2674,11 +2673,11 @@ public class XmppConnection implements Runnable {
}
public Jid findDiscoItemByFeature(final String feature) {
- final List<Entry<Jid, ServiceDiscoveryResult>> items = findDiscoItemsByFeature(feature);
- if (items.size() >= 1) {
- return items.get(0).getKey();
+ final var items = findDiscoItemsByFeature(feature);
+ if (items.isEmpty()) {
+ return null;
}
- return null;
+ return Iterables.getFirst(items, null).getKey();
}
public boolean r() {
@@ -3096,7 +3095,7 @@ public class XmppConnection implements Runnable {
new String[] {Namespace.HTTP_UPLOAD, Namespace.HTTP_UPLOAD_LEGACY}) {
List<Entry<Jid, ServiceDiscoveryResult>> items =
findDiscoItemsByFeature(namespace);
- if (items.size() > 0) {
+ if (!items.isEmpty()) {
try {
long maxsize =
Long.parseLong(
@@ -3136,7 +3135,7 @@ public class XmppConnection implements Runnable {
for (String namespace :
new String[] {Namespace.HTTP_UPLOAD, Namespace.HTTP_UPLOAD_LEGACY}) {
List<Entry<Jid, ServiceDiscoveryResult>> items = findDiscoItemsByFeature(namespace);
- if (items.size() > 0) {
+ if (!items.isEmpty()) {
try {
return Long.parseLong(
items.get(0)