pipeline sasl2 directly after stream start

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/entities/Account.java       |   4 
src/main/java/eu/siacs/conversations/ui/EditAccountActivity.java |   2 
src/main/java/eu/siacs/conversations/xml/Namespace.java          |   1 
src/main/java/eu/siacs/conversations/xml/TagWriter.java          |  11 
src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java    | 112 +
src/main/java/eu/siacs/conversations/xmpp/bind/Bind2.java        |  33 
6 files changed, 121 insertions(+), 42 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/entities/Account.java 🔗

@@ -57,16 +57,14 @@ public class Account extends AbstractEntity implements AvatarService.Avatarable
     public static final String PINNED_CHANNEL_BINDING = "pinned_channel_binding";
 
 
-    public static final int OPTION_USETLS = 0;
     public static final int OPTION_DISABLED = 1;
     public static final int OPTION_REGISTER = 2;
-    public static final int OPTION_USECOMPRESSION = 3;
     public static final int OPTION_MAGIC_CREATE = 4;
     public static final int OPTION_REQUIRES_ACCESS_MODE_CHANGE = 5;
     public static final int OPTION_LOGGED_IN_SUCCESSFULLY = 6;
     public static final int OPTION_HTTP_UPLOAD_AVAILABLE = 7;
-    public static final int OPTION_UNVERIFIED = 8;
     public static final int OPTION_FIXED_USERNAME = 9;
+    public static final int OPTION_QUICKSTART_AVAILABLE = 10;
 
     private static final String KEY_PGP_SIGNATURE = "pgp_signature";
     private static final String KEY_PGP_ID = "pgp_id";

src/main/java/eu/siacs/conversations/ui/EditAccountActivity.java 🔗

@@ -286,8 +286,6 @@ public class EditAccountActivity extends OmemoActivity implements OnAccountUpdat
                 mAccount = new Account(jid.asBareJid(), password);
                 mAccount.setPort(numericPort);
                 mAccount.setHostname(hostname);
-                mAccount.setOption(Account.OPTION_USETLS, true);
-                mAccount.setOption(Account.OPTION_USECOMPRESSION, true);
                 mAccount.setOption(Account.OPTION_REGISTER, registerNewAccount);
                 xmppConnectionService.createAccount(mAccount);
             }

src/main/java/eu/siacs/conversations/xml/Namespace.java 🔗

@@ -1,6 +1,7 @@
 package eu.siacs.conversations.xml;
 
 public final class Namespace {
+    public static final String STREAMS = "http://etherx.jabber.org/streams";
     public static final String DISCO_ITEMS = "http://jabber.org/protocol/disco#items";
     public static final String DISCO_INFO = "http://jabber.org/protocol/disco#info";
     public static final String EXTERNAL_SERVICE_DISCOVERY = "urn:xmpp:extdisco:2";

src/main/java/eu/siacs/conversations/xml/TagWriter.java 🔗

@@ -58,15 +58,20 @@ public class TagWriter {
             throw new IOException("output stream was null");
         }
         outputStream.write("<?xml version='1.0'?>");
-        outputStream.flush();
     }
 
-    public synchronized void writeTag(Tag tag) throws IOException {
+    public void writeTag(final Tag tag) throws IOException {
+        writeTag(tag, true);
+    }
+
+    public synchronized void writeTag(final Tag tag, final boolean flush) throws IOException {
         if (outputStream == null) {
             throw new IOException("output stream was null");
         }
         outputStream.write(tag.toString());
-        outputStream.flush();
+        if (flush) {
+            outputStream.flush();
+        }
     }
 
     public synchronized void writeElement(Element element) throws IOException {

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -89,6 +89,7 @@ import eu.siacs.conversations.xml.Namespace;
 import eu.siacs.conversations.xml.Tag;
 import eu.siacs.conversations.xml.TagWriter;
 import eu.siacs.conversations.xml.XmlReader;
+import eu.siacs.conversations.xmpp.bind.Bind2;
 import eu.siacs.conversations.xmpp.forms.Data;
 import eu.siacs.conversations.xmpp.jingle.OnJinglePacketReceived;
 import eu.siacs.conversations.xmpp.jingle.stanzas.JinglePacket;
@@ -155,6 +156,7 @@ public class XmppConnection implements Runnable {
     private TagWriter tagWriter = new TagWriter();
     private boolean shouldAuthenticate = true;
     private boolean inSmacksSession = false;
+    private boolean quickStartInProgress = false;
     private boolean isBound = false;
     private Element streamFeatures;
     private String streamId = null;
@@ -270,11 +272,11 @@ public class XmppConnection implements Runnable {
         }
         Log.d(Config.LOGTAG, account.getJid().asBareJid().toString() + ": connecting");
         features.encryptionEnabled = false;
-        inSmacksSession = false;
-        isBound = false;
+        this.inSmacksSession = false;
+        this.quickStartInProgress = false;
+        this.isBound = false;
         this.attempt++;
-        this.verifiedHostname =
-                null; // will be set if user entered hostname is being used or hostname was verified
+        this.verifiedHostname = null; // will be set if user entered hostname is being used or hostname was verified
         // with dnssec
         try {
             Socket localSocket;
@@ -310,14 +312,14 @@ public class XmppConnection implements Runnable {
 
                 try {
                     startXmpp(localSocket);
-                } catch (InterruptedException e) {
+                } catch (final InterruptedException e) {
                     Log.d(
                             Config.LOGTAG,
                             account.getJid().asBareJid()
                                     + ": thread was interrupted before beginning stream");
                     return;
-                } catch (Exception e) {
-                    throw new IOException(e.getMessage());
+                } catch (final Exception e) {
+                    throw new IOException("Could not start stream", e);
                 }
             } else {
                 final String domain = account.getServer();
@@ -477,7 +479,7 @@ public class XmppConnection implements Runnable {
      *
      * @return true if server returns with valid xmpp, false otherwise
      */
-    private boolean startXmpp(Socket socket) throws Exception {
+    private boolean startXmpp(final Socket socket) throws Exception {
         if (Thread.currentThread().isInterrupted()) {
             throw new InterruptedException();
         }
@@ -490,15 +492,22 @@ public class XmppConnection implements Runnable {
         tagWriter.setOutputStream(socket.getOutputStream());
         tagReader.setInputStream(socket.getInputStream());
         tagWriter.beginDocument();
-        sendStartStream();
+        final boolean quickStart;
+        if (socket instanceof SSLSocket) {
+            SSLSocketHelper.log(account, (SSLSocket) socket);
+            quickStart = establishStream(true);
+        } else {
+            quickStart = establishStream(false);
+        }
         final Tag tag = tagReader.readTag();
         if (Thread.currentThread().isInterrupted()) {
             throw new InterruptedException();
         }
-        if (socket instanceof SSLSocket) {
-            SSLSocketHelper.log(account, (SSLSocket) socket);
+        final boolean success = tag != null && tag.isStart("stream", Namespace.STREAMS);
+        if (success && quickStart) {
+            this.quickStartInProgress = true;
         }
-        return tag != null && tag.isStart("stream");
+        return success;
     }
 
     private SSLSocketFactory getSSLSocketFactory()
@@ -761,11 +770,12 @@ public class XmppConnection implements Runnable {
                 sendPostBindInitialization(streamManagementEnabled != null, carbonsEnabled != null);
             }
         }
+        this.quickStartInProgress = false;
         if (version == SaslMechanism.Version.SASL) {
             tagReader.reset();
-            sendStartStream();
+            sendStartStream(true);
             final Tag tag = tagReader.readTag();
-            if (tag != null && tag.isStart("stream")) {
+            if (tag != null && tag.isStart("stream", Namespace.STREAMS)) {
                 processStream();
                 return true;
             } else {
@@ -1119,11 +1129,14 @@ public class XmppConnection implements Runnable {
         final SSLSocket sslSocket = upgradeSocketToTls(socket);
         tagReader.setInputStream(sslSocket.getInputStream());
         tagWriter.setOutputStream(sslSocket.getOutputStream());
-        sendStartStream();
         Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": TLS connection established");
+        final boolean quickStart = establishStream(true);
+        if (quickStart) {
+            this.quickStartInProgress = true;
+        }
         features.encryptionEnabled = true;
         final Tag tag = tagReader.readTag();
-        if (tag != null && tag.isStart("stream")) {
+        if (tag != null && tag.isStart("stream", Namespace.STREAMS)) {
             SSLSocketHelper.log(account, sslSocket);
             processStream();
         } else {
@@ -1170,7 +1183,13 @@ public class XmppConnection implements Runnable {
         final boolean isSecure =
                 features.encryptionEnabled || Config.ALLOW_NON_TLS_CONNECTIONS || account.isOnion();
         final boolean needsBinding = !isBound && !account.isOptionSet(Account.OPTION_REGISTER);
-        if (this.streamFeatures.hasChild("starttls", Namespace.TLS)
+        if (this.quickStartInProgress) {
+            Log.d(
+                    Config.LOGTAG,
+                    account.getJid().asBareJid()
+                            + ": quick start in progress. ignoring features: "
+                            + XmlHelper.printElementNames(this.streamFeatures));
+        } else if (this.streamFeatures.hasChild("starttls", Namespace.TLS)
                 && !features.encryptionEnabled) {
             sendStartTLS();
         } else if (this.streamFeatures.hasChild("register", Namespace.REGISTER_STREAM_FEATURE)
@@ -1238,6 +1257,7 @@ public class XmppConnection implements Runnable {
         } else {
             authElement = this.streamFeatures.findChild("authentication", Namespace.SASL_2);
         }
+        //TODO externalize
         final Collection<String> mechanisms =
                 Collections2.transform(
                         Collections2.filter(
@@ -1261,6 +1281,8 @@ public class XmppConnection implements Runnable {
         final SaslMechanism.Factory factory = new SaslMechanism.Factory(account);
         this.saslMechanism = factory.of(mechanisms, channelBindings);
 
+        //TODO externalize checks
+
         if (saslMechanism == null) {
             Log.d(
                     Config.LOGTAG,
@@ -1282,6 +1304,7 @@ public class XmppConnection implements Runnable {
                             + "). Possible downgrade attack?");
             throw new StateChangingException(Account.State.DOWNGRADE_ATTACK);
         }
+        final boolean quickStartAvailable;
         final String firstMessage = saslMechanism.getClientFirstMessage();
         final Element authenticate;
         if (version == SaslMechanism.Version.SASL) {
@@ -1289,15 +1312,24 @@ public class XmppConnection implements Runnable {
             if (!Strings.isNullOrEmpty(firstMessage)) {
                 authenticate.setContent(firstMessage);
             }
+            quickStartAvailable = false;
         } else if (version == SaslMechanism.Version.SASL_2) {
             final Element inline = authElement.findChild("inline", Namespace.SASL_2);
             final boolean sm = inline != null && inline.hasChild("sm", "urn:xmpp:sm:3");
-            final Collection<String> bindFeatures = bindFeatures(inline);
+            final Collection<String> bindFeatures = Bind2.features(inline);
+            quickStartAvailable =
+                    sm
+                            && bindFeatures != null
+                            && bindFeatures.containsAll(Bind2.QUICKSTART_FEATURES);
             authenticate = generateAuthenticationRequest(firstMessage, bindFeatures, sm);
         } else {
             throw new AssertionError("Missing implementation for " + version);
         }
 
+        if (account.setOption(Account.OPTION_QUICKSTART_AVAILABLE, quickStartAvailable)) {
+            mXmppConnectionService.updateAccount(account);
+        }
+
         Log.d(
                 Config.LOGTAG,
                 account.getJid().toString()
@@ -1309,19 +1341,8 @@ public class XmppConnection implements Runnable {
         tagWriter.writeElement(authenticate);
     }
 
-    private static Collection<String> bindFeatures(final Element inline) {
-        final Element inlineBind2 =
-                inline != null ? inline.findChild("bind", Namespace.BIND2) : null;
-        final Element inlineBind2Inline =
-                inlineBind2 != null ? inlineBind2.findChild("inline", Namespace.BIND2) : null;
-        if (inlineBind2 == null) {
-            return null;
-        }
-        if (inlineBind2Inline == null) {
-            return Collections.emptyList();
-        }
-        return Collections2.transform(
-                inlineBind2Inline.getChildren(), c -> c == null ? null : c.getAttribute("var"));
+    private Element generateAuthenticationRequest(final String firstMessage) {
+        return generateAuthenticationRequest(firstMessage, Bind2.QUICKSTART_FEATURES, true);
     }
 
     private Element generateAuthenticationRequest(
@@ -1988,14 +2009,37 @@ public class XmppConnection implements Runnable {
         }
     }
 
-    private void sendStartStream() throws IOException {
+    private boolean establishStream(final boolean secureConnection) throws IOException {
+        final SaslMechanism saslMechanism = account.getPinnedMechanism();
+        if (secureConnection
+                && saslMechanism != null
+                && account.isOptionSet(Account.OPTION_QUICKSTART_AVAILABLE)) {
+            this.saslMechanism = saslMechanism;
+            final Element authenticate =
+                    generateAuthenticationRequest(saslMechanism.getClientFirstMessage());
+            authenticate.setAttribute("mechanism", saslMechanism.getMechanism());
+            sendStartStream(false);
+            tagWriter.writeElement(authenticate);
+            Log.d(
+                    Config.LOGTAG,
+                    account.getJid().toString()
+                            + ": quick start with "
+                            + saslMechanism.getMechanism());
+            return true;
+        } else {
+            sendStartStream(true);
+            return false;
+        }
+    }
+
+    private void sendStartStream(final boolean flush) throws IOException {
         final Tag stream = Tag.start("stream:stream");
         stream.setAttribute("to", account.getServer());
         stream.setAttribute("version", "1.0");
         stream.setAttribute("xml:lang", LocalizedContent.STREAM_LANGUAGE);
         stream.setAttribute("xmlns", "jabber:client");
-        stream.setAttribute("xmlns:stream", "http://etherx.jabber.org/streams");
-        tagWriter.writeTag(stream);
+        stream.setAttribute("xmlns:stream", Namespace.STREAMS);
+        tagWriter.writeTag(stream, flush);
     }
 
     private String createNewResource() {

src/main/java/eu/siacs/conversations/xmpp/bind/Bind2.java 🔗

@@ -0,0 +1,33 @@
+package eu.siacs.conversations.xmpp.bind;
+
+import com.google.common.collect.Collections2;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+import eu.siacs.conversations.xml.Element;
+import eu.siacs.conversations.xml.Namespace;
+
+public class Bind2 {
+
+    public static final Collection<String> QUICKSTART_FEATURES = Arrays.asList(
+            Namespace.CARBONS,
+            Namespace.STREAM_MANAGEMENT
+    );
+
+    public static Collection<String> features(final Element inline) {
+        final Element inlineBind2 =
+                inline != null ? inline.findChild("bind", Namespace.BIND2) : null;
+        final Element inlineBind2Inline =
+                inlineBind2 != null ? inlineBind2.findChild("inline", Namespace.BIND2) : null;
+        if (inlineBind2 == null) {
+            return null;
+        }
+        if (inlineBind2Inline == null) {
+            return Collections.emptyList();
+        }
+        return Collections2.transform(
+                inlineBind2Inline.getChildren(), c -> c == null ? null : c.getAttribute("var"));
+    }
+}