diff --git a/build.gradle b/build.gradle index 8caaf0b0e6385639d625f3159c312b556b668f24..94e28a41eacf8369500de3823d6569419312f2da 100644 --- a/build.gradle +++ b/build.gradle @@ -134,6 +134,9 @@ dependencies { testImplementation 'junit:junit:4.13.2' testImplementation 'org.robolectric:robolectric:4.14.1' + testImplementation 'org.mockito:mockito-core:5.14.2' + testImplementation 'net.bytebuddy:byte-buddy:1.15.11' + testImplementation 'net.bytebuddy:byte-buddy-agent:1.15.11' androidTestImplementation 'androidx.test.ext:junit:1.2.1' } diff --git a/src/main/java/eu/siacs/conversations/entities/Conversation.java b/src/main/java/eu/siacs/conversations/entities/Conversation.java index 63f0c1579cf3fb548be8bf744bd7a0825dcb6ec7..8c68246bb3dfb19a566ddcce77b72b76fd231716 100644 --- a/src/main/java/eu/siacs/conversations/entities/Conversation.java +++ b/src/main/java/eu/siacs/conversations/entities/Conversation.java @@ -110,6 +110,7 @@ import java.util.HashSet; import java.util.List; import java.util.ListIterator; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.Set; import java.util.Timer; @@ -218,7 +219,7 @@ public class Conversation extends AbstractEntity private int mode; private final JSONObject attributes; private Jid nextCounterpart; - private transient MucOptions mucOptions = null; + private final transient AtomicReference mucOptions = new AtomicReference<>(); private boolean messagesLeftOnServer = true; private ChatState mOutgoingChatState = Config.DEFAULT_CHAT_STATE; private ChatState mIncomingChatState = Config.DEFAULT_CHAT_STATE; @@ -1141,15 +1142,16 @@ public class Conversation extends AbstractEntity return getMucOptions().isPrivateAndNonAnonymous(); } - public synchronized MucOptions getMucOptions() { - if (this.mucOptions == null) { - this.mucOptions = new MucOptions(this); - } - return this.mucOptions; + public @NonNull MucOptions getMucOptions() { + return this.mucOptions.updateAndGet( + existing -> existing != null ? existing : new MucOptions(this) + ); } - public void resetMucOptions() { - this.mucOptions = null; + public @NonNull MucOptions resetMucOptions() { + return this.mucOptions.updateAndGet( + ignoredExisting -> new MucOptions(this) + ); } public void setContactJid(final Jid jid) { @@ -1344,7 +1346,7 @@ public class Conversation extends AbstractEntity public boolean storeInCache() { if ("cache".equals(getAttribute("storeMedia"))) return true; if ("shared".equals(getAttribute("storeMedia"))) return false; - if (mode == Conversation.MODE_MULTI && !mucOptions.isPrivateAndNonAnonymous()) return true; + if (mode == Conversation.MODE_MULTI && !getMucOptions().isPrivateAndNonAnonymous()) return true; return false; } diff --git a/src/main/java/eu/siacs/conversations/xmpp/manager/MultiUserChatManager.java b/src/main/java/eu/siacs/conversations/xmpp/manager/MultiUserChatManager.java index be523106094da92027fa306d89268d025099ba79..4fe8833f7cafa7fb917367a3a081c444d1c957fa 100644 --- a/src/main/java/eu/siacs/conversations/xmpp/manager/MultiUserChatManager.java +++ b/src/main/java/eu/siacs/conversations/xmpp/manager/MultiUserChatManager.java @@ -87,8 +87,7 @@ public class MultiUserChatManager extends AbstractManager { if (Config.MUC_LEAVE_BEFORE_JOIN) { unavailable(conversation); } - conversation.resetMucOptions(); - conversation.getMucOptions().setAutoPushConfiguration(autoPushConfiguration); + conversation.resetMucOptions().setAutoPushConfiguration(autoPushConfiguration); conversation.setHasMessagesLeftOnServer(false); final var disco = fetchDiscoInfo(conversation); diff --git a/src/test/java/eu/siacs/conversations/entities/ConversationGetMucOptionsRaceTest.java b/src/test/java/eu/siacs/conversations/entities/ConversationGetMucOptionsRaceTest.java new file mode 100644 index 0000000000000000000000000000000000000000..72d5977f50554147bb41fd59a4658f61d6283cc6 --- /dev/null +++ b/src/test/java/eu/siacs/conversations/entities/ConversationGetMucOptionsRaceTest.java @@ -0,0 +1,183 @@ +package eu.siacs.conversations.entities; + +import static net.bytebuddy.matcher.ElementMatchers.named; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.annotation.Config; +import org.robolectric.annotation.ConscryptMode; + +import android.os.Build; +import eu.siacs.conversations.Conversations; +import eu.siacs.conversations.xmpp.Jid; +import junit.framework.Assert; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.asm.AsmVisitorWrapper; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.jar.asm.MethodVisitor; +import net.bytebuddy.jar.asm.Opcodes; +import net.bytebuddy.jar.asm.Type; +import net.bytebuddy.pool.TypePool; + +@RunWith(RobolectricTestRunner.class) +@Config(sdk = Build.VERSION_CODES.TIRAMISU, application = Conversations.class) +@ConscryptMode(ConscryptMode.Mode.OFF) +public class ConversationGetMucOptionsRaceTest { + + static int fieldReadCount; + static volatile CountDownLatch remainingReads; + static volatile CountDownLatch resetDone; + + public static void gate() { + final var reads = remainingReads; + final var reset = resetDone; + if (reads == null || reset == null) return; + final boolean lastRead = reads.getCount() == 1; + reads.countDown(); + if (lastRead) { + try { + reset.await(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + static class GetMucOptionsInstrumentor extends MethodVisitor { + private int count = 0; + + GetMucOptionsInstrumentor(MethodVisitor delegate) { + super(Opcodes.ASM9, delegate); + } + + @Override + public void visitFieldInsn( + int opcode, String owner, String name, String descriptor + ) { + if (opcode == Opcodes.GETFIELD && "mucOptions".equals(name)) { + count++; + super.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName( + ConversationGetMucOptionsRaceTest.class), + "gate", + "()V", + false + ); + } + super.visitFieldInsn(opcode, owner, name, descriptor); + } + + @Override + public void visitEnd() { + fieldReadCount = count; + super.visitEnd(); + } + } + + @SuppressWarnings("unchecked") + @BeforeClass + public static void instrumentConversation() throws Exception { + Class.forName("net.bytebuddy.agent.ByteBuddyAgent") + .getMethod("install") + .invoke(null); + + final var strategy = (ClassLoadingStrategy) + Class.forName( + "net.bytebuddy.dynamic.loading.ClassReloadingStrategy") + .getMethod("fromInstalledAgent") + .invoke(null); + + new ByteBuddy() + .redefine(Conversation.class) + .visit(new AsmVisitorWrapper.ForDeclaredMethods() + .method( + named("getMucOptions"), + new AsmVisitorWrapper.ForDeclaredMethods + .MethodVisitorWrapper() { + @Override + public MethodVisitor wrap( + TypeDescription instrumentedType, + MethodDescription instrumentedMethod, + MethodVisitor methodVisitor, + Implementation.Context implementationContext, + TypePool typePool, + int writerFlags, + int readerFlags + ) { + return new GetMucOptionsInstrumentor( + methodVisitor); + } + } + ) + ) + .make() + .load(Conversation.class.getClassLoader(), strategy); + } + + @Test + public void testGetMucOptionsNeverReturnsNull() throws Throwable { + final var account = mock(Account.class); + when(account.getJid()).thenReturn( + Jid.ofLocalAndDomain("testAccount", "example.org")); + + final var conversation = new Conversation( + "Test MUC", + account, + Jid.ofLocalAndDomain("testMuc", "example.org"), + Conversation.MODE_MULTI + ); + conversation.getMucOptions(); + + remainingReads = new CountDownLatch(fieldReadCount); + resetDone = new CountDownLatch(1); + + final var result = new AtomicReference(); + final var error = new AtomicReference(); + + Thread reader = new Thread(() -> { + try { + result.set(conversation.getMucOptions()); + } catch (Throwable t) { + error.set(t); + } + }); + + Thread resetter = new Thread(() -> { + try { + remainingReads.await(); + conversation.resetMucOptions(); + resetDone.countDown(); + } catch (Throwable t) { + error.set(t); + } + }); + + reader.start(); + resetter.start(); + + reader.join(10_000); + resetter.join(10_000); + + remainingReads = null; + resetDone = null; + + if (error.get() != null) throw error.get(); + + Assert.assertNotNull( + "getMucOptions() returned null" + + " — the field must not be re-read after the null check", + result.get() + ); + } +}