Improve auth error handling and state machine

Sam Whited created

Change summary

src/main/java/eu/siacs/conversations/crypto/sasl/AuthenticationException.java | 11 
src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java               | 17 
src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java           | 27 
src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java               | 13 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java                 | 15 
5 files changed, 43 insertions(+), 40 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/sasl/DigestMd5.java 🔗

@@ -21,11 +21,6 @@ public class DigestMd5 extends SaslMechanism {
 		return "DIGEST-MD5";
 	}
 
-	private enum State {
-		INITIAL,
-		RESPONSE_SENT,
-	}
-
 	private State state = State.INITIAL;
 
 	@Override
@@ -53,8 +48,7 @@ public class DigestMd5 extends SaslMechanism {
 					final byte[] y = md.digest(x.getBytes(Charset.defaultCharset()));
 					final String cNonce = new BigInteger(100, rng).toString(32);
 					final byte[] a1 = CryptoHelper.concatenateByteArrays(y,
-							(":" + nonce + ":" + cNonce).getBytes(Charset
-																										.defaultCharset()));
+							(":" + nonce + ":" + cNonce).getBytes(Charset.defaultCharset()));
 					final String a2 = "AUTHENTICATE:" + digestUri;
 					final String ha1 = CryptoHelper.bytesToHex(md.digest(a1));
 					final String ha2 = CryptoHelper.bytesToHex(md.digest(a2.getBytes(Charset
@@ -72,13 +66,16 @@ public class DigestMd5 extends SaslMechanism {
 							saslString.getBytes(Charset.defaultCharset()),
 							Base64.NO_WRAP);
 				} catch (final NoSuchAlgorithmException e) {
-					return "";
+					throw new AuthenticationException(e);
 				}
 
 				return encodedResponse;
 			case RESPONSE_SENT:
-				return "";
+				state = State.VALID_SERVER_RESPONSE;
+				break;
+			default:
+				throw new InvalidStateException(state);
 		}
-		return "";
+		return null;
 	}
 }

src/main/java/eu/siacs/conversations/crypto/sasl/SaslMechanism.java 🔗

@@ -11,6 +11,33 @@ public abstract class SaslMechanism {
 	final protected Account account;
 	final protected SecureRandom rng;
 
+	protected static enum State {
+		INITIAL,
+		AUTH_TEXT_SENT,
+		RESPONSE_SENT,
+		VALID_SERVER_RESPONSE,
+	}
+
+	public static class AuthenticationException extends Exception {
+		public AuthenticationException(final String message) {
+			super(message);
+		}
+
+		public AuthenticationException(final Exception inner) {
+			super(inner);
+		}
+	}
+
+	public static class InvalidStateException extends AuthenticationException {
+		public InvalidStateException(final String message) {
+			super(message);
+		}
+
+		public InvalidStateException(final State state) {
+			this("Invalid state: " + state.toString());
+		}
+	}
+
 	public SaslMechanism(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
 		this.tagWriter = tagWriter;
 		this.account = account;

src/main/java/eu/siacs/conversations/crypto/sasl/ScramSha1.java 🔗

@@ -33,13 +33,6 @@ public class ScramSha1 extends SaslMechanism {
 		HMAC = new HMac(new SHA1Digest());
 	}
 
-	private enum State {
-		INITIAL,
-		AUTH_TEXT_SENT,
-		RESPONSE_SENT,
-		VALID_SERVER_RESPONSE,
-	}
-
 	private State state = State.INITIAL;
 
 	public ScramSha1(final TagWriter tagWriter, final Account account, final SecureRandom rng) {
@@ -56,11 +49,9 @@ public class ScramSha1 extends SaslMechanism {
 
 	@Override
 	public String getClientFirstMessage() {
-		if (clientFirstMessageBare.isEmpty()) {
+		if (clientFirstMessageBare.isEmpty() && state == State.INITIAL) {
 			clientFirstMessageBare = "n=" + CryptoHelper.saslPrep(account.getUsername()) +
 				",r=" + this.clientNonce;
-		}
-		if (state == State.INITIAL) {
 			state = State.AUTH_TEXT_SENT;
 		}
 		return Base64.encodeToString(
@@ -157,7 +148,7 @@ public class ScramSha1 extends SaslMechanism {
 				state = State.VALID_SERVER_RESPONSE;
 				return "";
 			default:
-				throw new AuthenticationException("Invalid state: " + state);
+				throw new InvalidStateException(state);
 		}
 	}
 

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -39,7 +39,6 @@ import javax.net.ssl.SSLSocketFactory;
 import javax.net.ssl.X509TrustManager;
 
 import eu.siacs.conversations.Config;
-import eu.siacs.conversations.crypto.sasl.AuthenticationException;
 import eu.siacs.conversations.crypto.sasl.DigestMd5;
 import eu.siacs.conversations.crypto.sasl.Plain;
 import eu.siacs.conversations.crypto.sasl.SaslMechanism;
@@ -284,14 +283,14 @@ public class XmppConnection implements Runnable {
 							} else if (nextTag.isStart("compressed")) {
 								switchOverToZLib(nextTag);
 							} else if (nextTag.isStart("success")) {
-								Log.d(Config.LOGTAG, account.getJid().toBareJid().toString() + ": logged in");
 								final String challenge = tagReader.readElement(nextTag).getContent();
 								try {
 									saslMechanism.getResponse(challenge);
-								} catch (final AuthenticationException e) {
+								} catch (final SaslMechanism.AuthenticationException e) {
 									disconnect(true);
 									Log.e(Config.LOGTAG, String.valueOf(e));
 								}
+								Log.d(Config.LOGTAG, account.getJid().toBareJid().toString() + ": logged in");
 								tagReader.reset();
 								sendStartStream();
 								processStream(tagReader.readTag());
@@ -306,7 +305,7 @@ public class XmppConnection implements Runnable {
 										"urn:ietf:params:xml:ns:xmpp-sasl");
 								try {
 									response.setContent(saslMechanism.getResponse(challenge));
-								} catch (final AuthenticationException e) {
+								} catch (final SaslMechanism.AuthenticationException e) {
 									// TODO: Send auth abort tag.
 									Log.e(Config.LOGTAG, e.toString());
 								}
@@ -643,10 +642,10 @@ public class XmppConnection implements Runnable {
 				saslMechanism = new Plain(tagWriter, account);
 				auth.setAttribute("mechanism", Plain.getMechanism());
 			}
-			if (!saslMechanism.getClientFirstMessage().isEmpty()) {
-				auth.setContent(saslMechanism.getClientFirstMessage());
-			}
-			tagWriter.writeElement(auth);
+            if (!saslMechanism.getClientFirstMessage().isEmpty()) {
+                auth.setContent(saslMechanism.getClientFirstMessage());
+            }
+            tagWriter.writeElement(auth);
 		} else if (this.streamFeatures.hasChild("sm", "urn:xmpp:sm:"
 					+ smVersion)
 				&& streamId != null) {