properly handle key transport messages. use prekeyparsing only when that attribute is set

Daniel Gultsch created

Change summary

src/main/java/eu/siacs/conversations/crypto/axolotl/AxolotlService.java     | 19 
src/main/java/eu/siacs/conversations/crypto/axolotl/XmppAxolotlMessage.java | 24 
src/main/java/eu/siacs/conversations/crypto/axolotl/XmppAxolotlSession.java | 30 
src/main/java/eu/siacs/conversations/parser/MessageParser.java              | 26 
src/main/java/eu/siacs/conversations/xmpp/jingle/JingleConnection.java      |  2 
5 files changed, 54 insertions(+), 47 deletions(-)

Detailed changes

src/main/java/eu/siacs/conversations/crypto/axolotl/AxolotlService.java 🔗

@@ -1313,16 +1313,15 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
 		return session;
 	}
 
-	public XmppAxolotlMessage.XmppAxolotlPlaintextMessage processReceivingPayloadMessage(XmppAxolotlMessage message) {
+	public XmppAxolotlMessage.XmppAxolotlPlaintextMessage processReceivingPayloadMessage(XmppAxolotlMessage message, boolean postponePreKeyMessageHandling) {
 		XmppAxolotlMessage.XmppAxolotlPlaintextMessage plaintextMessage = null;
 
 		XmppAxolotlSession session = getReceivingSession(message);
 		try {
 			plaintextMessage = message.decrypt(session, getOwnDeviceId());
-			Integer preKeyId = session.getPreKeyId();
+			Integer preKeyId = session.getPreKeyIdAndReset();
 			if (preKeyId != null) {
-				publishBundlesIfNeeded(false, false);
-				session.resetPreKeyId();
+				postPreKeyMessageHandling(session, preKeyId, postponePreKeyMessageHandling);
 			}
 		} catch (CryptoFailedException e) {
 			Log.w(Config.LOGTAG, getLogprefix(account) + "Failed to decrypt message from "+message.getFrom()+": " + e.getMessage());
@@ -1335,12 +1334,22 @@ public class AxolotlService implements OnAdvancedStreamFeaturesLoaded {
 		return plaintextMessage;
 	}
 
-	public XmppAxolotlMessage.XmppAxolotlKeyTransportMessage processReceivingKeyTransportMessage(XmppAxolotlMessage message) {
+	private void postPreKeyMessageHandling(final XmppAxolotlSession session, int preKeyId, final boolean postpone) {
+		Log.d(Config.LOGTAG,account.getJid().toBareJid()+": postPreKeyMessageHandling() preKeyId="+preKeyId+", postpone="+Boolean.toString(postpone));
+		//TODO: do not republish if we already removed this preKeyId
+		publishBundlesIfNeeded(false, false);
+	}
+
+	public XmppAxolotlMessage.XmppAxolotlKeyTransportMessage processReceivingKeyTransportMessage(XmppAxolotlMessage message, final boolean postponePreKeyMessageHandling) {
 		XmppAxolotlMessage.XmppAxolotlKeyTransportMessage keyTransportMessage;
 
 		XmppAxolotlSession session = getReceivingSession(message);
 		try {
 			keyTransportMessage = message.getParameters(session, getOwnDeviceId());
+			Integer preKeyId = session.getPreKeyIdAndReset();
+			if (preKeyId != null) {
+				postPreKeyMessageHandling(session, preKeyId, postponePreKeyMessageHandling);
+			}
 		} catch (CryptoFailedException e) {
 			Log.d(Config.LOGTAG,"could not decrypt keyTransport message "+e.getMessage());
 			keyTransportMessage = null;

src/main/java/eu/siacs/conversations/crypto/axolotl/XmppAxolotlMessage.java 🔗

@@ -2,6 +2,7 @@ package eu.siacs.conversations.crypto.axolotl;
 
 import android.util.Base64;
 import android.util.Log;
+import android.util.SparseArray;
 
 
 import java.security.InvalidAlgorithmParameterException;
@@ -43,7 +44,7 @@ public class XmppAxolotlMessage {
 	private byte[] ciphertext = null;
 	private byte[] authtagPlusInnerKey = null;
 	private byte[] iv = null;
-	private final Map<Integer, XmppAxolotlSession.AxolotlKey> keys;
+	private final SparseArray<XmppAxolotlSession.AxolotlKey> keys;
 	private final Jid from;
 	private final int sourceDeviceId;
 
@@ -99,7 +100,7 @@ public class XmppAxolotlMessage {
 			throw new IllegalArgumentException("invalid source id");
 		}
 		List<Element> keyElements = header.getChildren();
-		this.keys = new HashMap<>(keyElements.size());
+		this.keys = new SparseArray<>();
 		for (Element keyElement : keyElements) {
 			switch (keyElement.getName()) {
 				case KEYTAG:
@@ -132,7 +133,7 @@ public class XmppAxolotlMessage {
 	public XmppAxolotlMessage(Jid from, int sourceDeviceId) {
 		this.from = from;
 		this.sourceDeviceId = sourceDeviceId;
-		this.keys = new HashMap<>();
+		this.keys = new SparseArray<>();
 		this.iv = generateIv();
 		this.innerKey = generateKey();
 	}
@@ -159,6 +160,10 @@ public class XmppAxolotlMessage {
 		return iv;
 	}
 
+	public boolean hasPayload() {
+		return ciphertext != null;
+	}
+
 	public void encrypt(String plaintext) throws CryptoFailedException {
 		try {
 			SecretKey secretKey = new SecretKeySpec(innerKey, KEYTYPE);
@@ -205,10 +210,6 @@ public class XmppAxolotlMessage {
 		return sourceDeviceId;
 	}
 
-	public byte[] getCiphertext() {
-		return ciphertext;
-	}
-
 	public void addDevice(XmppAxolotlSession session) {
 		XmppAxolotlSession.AxolotlKey key;
 		if (authtagPlusInnerKey != null) {
@@ -233,13 +234,13 @@ public class XmppAxolotlMessage {
 		Element encryptionElement = new Element(CONTAINERTAG, AxolotlService.PEP_PREFIX);
 		Element headerElement = encryptionElement.addChild(HEADER);
 		headerElement.setAttribute(SOURCEID, sourceDeviceId);
-		for (Map.Entry<Integer, XmppAxolotlSession.AxolotlKey> keyEntry : keys.entrySet()) {
+		for(int i = 0; i < keys.size(); ++i) {
 			Element keyElement = new Element(KEYTAG);
-			keyElement.setAttribute(REMOTEID, keyEntry.getKey());
-			if (keyEntry.getValue().prekey) {
+			keyElement.setAttribute(REMOTEID, keys.keyAt(i));
+			if (keys.valueAt(i).prekey) {
 				keyElement.setAttribute("prekey","true");
 			}
-			keyElement.setContent(Base64.encodeToString(keyEntry.getValue().key, Base64.NO_WRAP));
+			keyElement.setContent(Base64.encodeToString(keys.valueAt(i).key, Base64.NO_WRAP));
 			headerElement.addChild(keyElement);
 		}
 		headerElement.addChild(IVTAG).setContent(Base64.encodeToString(iv, Base64.NO_WRAP));
@@ -267,7 +268,6 @@ public class XmppAxolotlMessage {
 		byte[] key = unpackKey(session, sourceDeviceId);
 		if (key != null) {
 			try {
-
 				if (key.length >= 32) {
 					int authtaglength = key.length - 16;
 					Log.d(Config.LOGTAG,"found auth tag as part of omemo key");

src/main/java/eu/siacs/conversations/crypto/axolotl/XmppAxolotlSession.java 🔗

@@ -43,15 +43,12 @@ public class XmppAxolotlSession implements Comparable<XmppAxolotlSession> {
 		this.account = account;
 	}
 
-	public Integer getPreKeyId() {
+	public Integer getPreKeyIdAndReset() {
+		final Integer preKeyId = this.preKeyId;
+		this.preKeyId = null;
 		return preKeyId;
 	}
 
-	public void resetPreKeyId() {
-
-		preKeyId = null;
-	}
-
 	public String getFingerprint() {
 		return identityKey == null ? null : CryptoHelper.bytesToHex(identityKey.getPublicKey().serialize());
 	}
@@ -87,11 +84,10 @@ public class XmppAxolotlSession implements Comparable<XmppAxolotlSession> {
 		FingerprintStatus status = getTrust();
 		if (!status.isCompromised()) {
 			try {
-				CiphertextMessage ciphertextMessage;
-				try {
-					ciphertextMessage = new PreKeySignalMessage(encryptedKey.key);
-					Optional<Integer> optionalPreKeyId = ((PreKeySignalMessage) ciphertextMessage).getPreKeyId();
-					IdentityKey identityKey = ((PreKeySignalMessage) ciphertextMessage).getIdentityKey();
+				if (encryptedKey.prekey) {
+					PreKeySignalMessage preKeySignalMessage = new PreKeySignalMessage(encryptedKey.key);
+					Optional<Integer> optionalPreKeyId = preKeySignalMessage.getPreKeyId();
+					IdentityKey identityKey = preKeySignalMessage.getIdentityKey();
 					if (!optionalPreKeyId.isPresent()) {
 						throw new CryptoFailedException("PreKeyWhisperMessage did not contain a PreKeyId");
 					}
@@ -100,15 +96,13 @@ public class XmppAxolotlSession implements Comparable<XmppAxolotlSession> {
 						throw new CryptoFailedException("Received PreKeyWhisperMessage but preexisting identity key changed.");
 					}
 					this.identityKey = identityKey;
-				} catch (InvalidVersionException | InvalidMessageException e) {
-					ciphertextMessage = new SignalMessage(encryptedKey.key);
-				}
-				if (ciphertextMessage instanceof PreKeySignalMessage) {
-					plaintext = cipher.decrypt((PreKeySignalMessage) ciphertextMessage);
+					plaintext = cipher.decrypt(preKeySignalMessage);
 				} else {
-					plaintext = cipher.decrypt((SignalMessage) ciphertextMessage);
+					SignalMessage signalMessage = new SignalMessage(encryptedKey.key);
+					plaintext = cipher.decrypt(signalMessage);
+					preKeyId = null; //better safe than sorry because we use that to do special after prekey handling
 				}
-			} catch (InvalidKeyException | LegacyMessageException | InvalidMessageException | DuplicateMessageException | NoSessionException | InvalidKeyIdException | UntrustedIdentityException e) {
+			} catch (InvalidVersionException | InvalidKeyException | LegacyMessageException | InvalidMessageException | DuplicateMessageException | NoSessionException | InvalidKeyIdException | UntrustedIdentityException e) {
 				if (!(e instanceof DuplicateMessageException)) {
 					e.printStackTrace();
 				}

src/main/java/eu/siacs/conversations/parser/MessageParser.java 🔗

@@ -163,24 +163,28 @@ public class MessageParser extends AbstractParser implements OnMessagePacketRece
 		return false;
 	}
 
-	private Message parseAxolotlChat(Element axolotlMessage, Jid from, Conversation conversation, int status) {
-		AxolotlService service = conversation.getAccount().getAxolotlService();
-		XmppAxolotlMessage xmppAxolotlMessage;
+	private Message parseAxolotlChat(Element axolotlMessage, Jid from, Conversation conversation, int status, boolean postpone) {
+		final AxolotlService service = conversation.getAccount().getAxolotlService();
+		final XmppAxolotlMessage xmppAxolotlMessage;
 		try {
 			xmppAxolotlMessage = XmppAxolotlMessage.fromElement(axolotlMessage, from.toBareJid());
 		} catch (Exception e) {
 			Log.d(Config.LOGTAG, conversation.getAccount().getJid().toBareJid() + ": invalid omemo message received " + e.getMessage());
 			return null;
 		}
-		XmppAxolotlMessage.XmppAxolotlPlaintextMessage plaintextMessage = service.processReceivingPayloadMessage(xmppAxolotlMessage);
-		if (plaintextMessage != null) {
-			Message finishedMessage = new Message(conversation, plaintextMessage.getPlaintext(), Message.ENCRYPTION_AXOLOTL, status);
-			finishedMessage.setFingerprint(plaintextMessage.getFingerprint());
-			Log.d(Config.LOGTAG, AxolotlService.getLogprefix(finishedMessage.getConversation().getAccount()) + " Received Message with session fingerprint: " + plaintextMessage.getFingerprint());
-			return finishedMessage;
+		if (xmppAxolotlMessage.hasPayload()) {
+			final XmppAxolotlMessage.XmppAxolotlPlaintextMessage plaintextMessage = service.processReceivingPayloadMessage(xmppAxolotlMessage, postpone);
+			if (plaintextMessage != null) {
+				Message finishedMessage = new Message(conversation, plaintextMessage.getPlaintext(), Message.ENCRYPTION_AXOLOTL, status);
+				finishedMessage.setFingerprint(plaintextMessage.getFingerprint());
+				Log.d(Config.LOGTAG, AxolotlService.getLogprefix(finishedMessage.getConversation().getAccount()) + " Received Message with session fingerprint: " + plaintextMessage.getFingerprint());
+				return finishedMessage;
+			}
 		} else {
-			return null;
+			Log.d(Config.LOGTAG,conversation.getAccount().getJid().toBareJid()+": received OMEMO key transport message");
+			service.processReceivingKeyTransportMessage(xmppAxolotlMessage, postpone);
 		}
+		return null;
 	}
 
 	private class Invite {
@@ -468,7 +472,7 @@ public class MessageParser extends AbstractParser implements OnMessagePacketRece
 				} else {
 					origin = from;
 				}
-				message = parseAxolotlChat(axolotlEncrypted, origin, conversation, status);
+				message = parseAxolotlChat(axolotlEncrypted, origin, conversation, status, query != null);
 				if (message == null) {
 					if (query == null &&  extractChatState(mXmppConnectionService.find(account, counterpart.toBareJid()), isTypeGroupChat, packet)) {
 						mXmppConnectionService.updateConversationUi();

src/main/java/eu/siacs/conversations/xmpp/jingle/JingleConnection.java 🔗

@@ -448,7 +448,7 @@ public class JingleConnection implements Transferable {
 				}
 				this.file = this.mXmppConnectionService.getFileBackend().getFile(message, false);
 				if (mXmppAxolotlMessage != null) {
-					XmppAxolotlMessage.XmppAxolotlKeyTransportMessage transportMessage = account.getAxolotlService().processReceivingKeyTransportMessage(mXmppAxolotlMessage);
+					XmppAxolotlMessage.XmppAxolotlKeyTransportMessage transportMessage = account.getAxolotlService().processReceivingKeyTransportMessage(mXmppAxolotlMessage, false);
 					if (transportMessage != null) {
 						message.setEncryption(Message.ENCRYPTION_AXOLOTL);
 						this.file.setKey(transportMessage.getKey());