Make sure timeout doesn't fire if we get a response and vice versa

Stephen Paul Weber created

Change summary

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java | 43 +++-
1 file changed, 26 insertions(+), 17 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/xmpp/XmppConnection.java 🔗

@@ -49,6 +49,7 @@ import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -150,7 +151,7 @@ public class XmppConnection implements Runnable {
     private final HashMap<Jid, ServiceDiscoveryResult> disco = new HashMap<>();
     private final HashMap<String, Jid> commands = new HashMap<>();
     private final SparseArray<AbstractAcknowledgeableStanza> mStanzaQueue = new SparseArray<>();
-    private final Hashtable<String, Pair<IqPacket, OnIqPacketReceived>> packetCallbacks =
+    private final Hashtable<String, Pair<IqPacket, Pair<OnIqPacketReceived, ScheduledFuture>>> packetCallbacks =
             new Hashtable<>();
     private final Set<OnAdvancedStreamFeaturesLoaded> advancedStreamFeaturesLoadedListeners =
             new HashSet<>();
@@ -1169,13 +1170,16 @@ public class XmppConnection implements Runnable {
         } else {
             OnIqPacketReceived callback = null;
             synchronized (this.packetCallbacks) {
-                final Pair<IqPacket, OnIqPacketReceived> packetCallbackDuple =
+                final Pair<IqPacket, Pair<OnIqPacketReceived, ScheduledFuture>> packetCallbackDuple =
                         packetCallbacks.get(packet.getId());
                 if (packetCallbackDuple != null) {
+                    ScheduledFuture timeoutFuture = packetCallbackDuple.second.second;
                     // Packets to the server should have responses from the server
                     if (packetCallbackDuple.first.toServer(account)) {
                         if (packet.fromServer(account)) {
-                            callback = packetCallbackDuple.second;
+                            if (timeoutFuture == null || timeoutFuture.cancel(false)) {
+                                callback = packetCallbackDuple.second.first;
+                            }
                             packetCallbacks.remove(packet.getId());
                         } else {
                             Log.e(
@@ -1186,7 +1190,9 @@ public class XmppConnection implements Runnable {
                     } else {
                         if (packet.getFrom() != null
                                 && packet.getFrom().equals(packetCallbackDuple.first.getTo())) {
-                            callback = packetCallbackDuple.second;
+                            if (timeoutFuture == null || timeoutFuture.cancel(false)) {
+                                callback = packetCallbackDuple.second.first;
+                            }
                             packetCallbacks.remove(packet.getId());
                         } else {
                             Log.e(
@@ -1829,11 +1835,13 @@ public class XmppConnection implements Runnable {
                             + ": clearing "
                             + this.packetCallbacks.size()
                             + " iq callbacks");
-            final Iterator<Pair<IqPacket, OnIqPacketReceived>> iterator =
+            final Iterator<Pair<IqPacket, Pair<OnIqPacketReceived, ScheduledFuture>>> iterator =
                     this.packetCallbacks.values().iterator();
             while (iterator.hasNext()) {
-                Pair<IqPacket, OnIqPacketReceived> entry = iterator.next();
-                callbacks.add(entry.second);
+                Pair<IqPacket, Pair<OnIqPacketReceived, ScheduledFuture>> entry = iterator.next();
+                if (entry.second.second == null || entry.second.second.cancel(false)) {
+                    callbacks.add(entry.second.first);
+                }
                 iterator.remove();
             }
         }
@@ -2265,19 +2273,20 @@ public class XmppConnection implements Runnable {
         }
         if (callback != null) {
             synchronized (this.packetCallbacks) {
-                packetCallbacks.put(packet.getId(), new Pair<>(packet, callback));
+                ScheduledFuture timeoutFuture = null;
+                if (timeout != null) {
+                    timeoutFuture = SCHEDULER.schedule(() -> {
+                        synchronized (this.packetCallbacks) {
+                            final IqPacket failurePacket = new IqPacket(IqPacket.TYPE.TIMEOUT);
+                            final Pair<IqPacket, Pair<OnIqPacketReceived, ScheduledFuture>> removedCallback = packetCallbacks.remove(packet.getId());
+                            if (removedCallback != null) removedCallback.second.first.onIqPacketReceived(account, failurePacket);
+                        }
+                    }, timeout, TimeUnit.SECONDS);
+                }
+                packetCallbacks.put(packet.getId(), new Pair<>(packet, new Pair<>(callback, timeoutFuture)));
             }
         }
         this.sendPacket(packet, force);
-        if (timeout != null) {
-            SCHEDULER.schedule(() -> {
-                synchronized (this.packetCallbacks) {
-                    final IqPacket failurePacket = new IqPacket(IqPacket.TYPE.TIMEOUT);
-                    final Pair<IqPacket, OnIqPacketReceived> removedCallback = packetCallbacks.remove(packet.getId());
-                    if (removedCallback != null) removedCallback.second.onIqPacketReceived(account, failurePacket);
-                }
-            }, timeout, TimeUnit.SECONDS);
-        }
         return packet.getId();
     }