prevent race condition when fetching device ids

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/crypto/axolotl/AxolotlService.java    | 86 
src/main/java/eu/siacs/conversations/crypto/axolotl/FingerprintStatus.java |  4 
2 files changed, 47 insertions(+), 43 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/axolotl/AxolotlService.java 🔗

@@ -1021,28 +1021,33 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
 		}
 		if (packet != null) {
 			mXmppConnectionService.sendIqPacket(account, packet, (account, response) -> {
-				synchronized (fetchDeviceIdsMap) {
-					List<OnDeviceIdsFetched> callbacks = fetchDeviceIdsMap.remove(jid);
-					if (response.getType() == IqPacket.TYPE.RESULT) {
-						fetchDeviceListStatus.put(jid, true);
-						Element item = mXmppConnectionService.getIqParser().getItem(response);
-						Set<Integer> deviceIds = mXmppConnectionService.getIqParser().deviceIds(item);
-						registerDevices(jid, deviceIds);
-						if (callbacks != null) {
-							for (OnDeviceIdsFetched c : callbacks) {
-								c.fetched(jid, deviceIds);
-							}
+				if (response.getType() == IqPacket.TYPE.RESULT) {
+					fetchDeviceListStatus.put(jid, true);
+					Element item = mXmppConnectionService.getIqParser().getItem(response);
+					Set<Integer> deviceIds = mXmppConnectionService.getIqParser().deviceIds(item);
+					registerDevices(jid, deviceIds);
+					final List<OnDeviceIdsFetched> callbacks;
+					synchronized (fetchDeviceIdsMap) {
+						callbacks = fetchDeviceIdsMap.remove(jid);
+					}
+					if (callbacks != null) {
+						for (OnDeviceIdsFetched c : callbacks) {
+							c.fetched(jid, deviceIds);
 						}
+					}
+				} else {
+					if (response.getType() == IqPacket.TYPE.TIMEOUT) {
+						fetchDeviceListStatus.remove(jid);
 					} else {
-						if (response.getType() == IqPacket.TYPE.TIMEOUT) {
-							fetchDeviceListStatus.remove(jid);
-						} else {
-							fetchDeviceListStatus.put(jid, false);
-						}
-						if (callbacks != null) {
-							for (OnDeviceIdsFetched c : callbacks) {
-								c.fetched(jid, null);
-							}
+						fetchDeviceListStatus.put(jid, false);
+					}
+					final List<OnDeviceIdsFetched> callbacks;
+					synchronized (fetchDeviceIdsMap) {
+						callbacks = fetchDeviceIdsMap.remove(jid);
+					}
+					if (callbacks != null) {
+						for (OnDeviceIdsFetched c : callbacks) {
+							c.fetched(jid, null);
 						}
 					}
 				}
@@ -1157,8 +1162,9 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
 		Set<SignalProtocolAddress> addresses = new HashSet<>();
 		for (Jid jid : getCryptoTargets(conversation)) {
 			Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Finding devices without session for " + jid);
-			if (deviceIds.get(jid) != null) {
-				for (Integer foreignId : this.deviceIds.get(jid)) {
+			Set<Integer> ids = deviceIds.get(jid);
+			if (deviceIds.get(jid) != null && !ids.isEmpty()) {
+				for (Integer foreignId : ids) {
 					SignalProtocolAddress address = new SignalProtocolAddress(jid.toString(), foreignId);
 					if (sessions.get(address) == null) {
 						IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey();
@@ -1181,22 +1187,21 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
 				Log.w(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Have no target devices in PEP!");
 			}
 		}
-		if (deviceIds.get(account.getJid().asBareJid()) != null) {
-			for (Integer ownId : this.deviceIds.get(account.getJid().asBareJid())) {
-				SignalProtocolAddress address = new SignalProtocolAddress(account.getJid().asBareJid().toString(), ownId);
-				if (sessions.get(address) == null) {
-					IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey();
-					if (identityKey != null) {
-						Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Already have session for " + address.toString() + ", adding to cache...");
-						XmppAxolotlSession session = new XmppAxolotlSession(account, axolotlStore, address, identityKey);
-						sessions.put(address, session);
+		Set<Integer> ownIds = this.deviceIds.get(account.getJid().asBareJid());
+		for (Integer ownId : (ownIds != null ? ownIds : new HashSet<Integer>())) {
+			SignalProtocolAddress address = new SignalProtocolAddress(account.getJid().asBareJid().toString(), ownId);
+			if (sessions.get(address) == null) {
+				IdentityKey identityKey = axolotlStore.loadSession(address).getSessionState().getRemoteIdentityKey();
+				if (identityKey != null) {
+					Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Already have session for " + address.toString() + ", adding to cache...");
+					XmppAxolotlSession session = new XmppAxolotlSession(account, axolotlStore, address, identityKey);
+					sessions.put(address, session);
+				} else {
+					Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Found device " + account.getJid().asBareJid() + ":" + ownId);
+					if (fetchStatusMap.get(address) != FetchStatus.ERROR) {
+						addresses.add(address);
 					} else {
-						Log.d(Config.LOGTAG, AxolotlService.getLogprefix(account) + "Found device " + account.getJid().asBareJid() + ":" + ownId);
-						if (fetchStatusMap.get(address) != FetchStatus.ERROR) {
-							addresses.add(address);
-						} else {
-							Log.d(Config.LOGTAG, getLogprefix(account) + "skipping over " + address + " because it's broken");
-						}
+						Log.d(Config.LOGTAG, getLogprefix(account) + "skipping over " + address + " because it's broken");
 					}
 				}
 			}
@@ -1215,12 +1220,7 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
 		}
 		Log.d(Config.LOGTAG, account.getJid().asBareJid() + ": createSessionsIfNeeded() - jids with empty device list: " + jidsWithEmptyDeviceList);
 		if (jidsWithEmptyDeviceList.size() > 0) {
-			fetchDeviceIds(jidsWithEmptyDeviceList, new OnMultipleDeviceIdFetched() {
-				@Override
-				public void fetched() {
-					createSessionsIfNeededActual(conversation);
-				}
-			});
+			fetchDeviceIds(jidsWithEmptyDeviceList, () -> createSessionsIfNeededActual(conversation));
 			return true;
 		} else {
 			return createSessionsIfNeededActual(conversation);

src/main/java/eu/siacs/conversations/crypto/axolotl/FingerprintStatus.java 🔗

@@ -78,6 +78,10 @@ public class FingerprintStatus implements Comparable<FingerprintStatus> {
         return status;
     }
 
+    public static FingerprintStatus createActive(Boolean trusted) {
+        return createActive(trusted != null && trusted);
+    }
+
     public static FingerprintStatus createActive(boolean trusted) {
         final FingerprintStatus status = new FingerprintStatus();
         status.trust = trusted ? Trust.TRUSTED : Trust.UNTRUSTED;