ConversationGetMucOptionsRaceTest.java

  1package eu.siacs.conversations.entities;
  2
  3import static net.bytebuddy.matcher.ElementMatchers.named;
  4import static org.mockito.Mockito.mock;
  5import static org.mockito.Mockito.when;
  6
  7import java.util.concurrent.CountDownLatch;
  8import java.util.concurrent.atomic.AtomicReference;
  9
 10import org.junit.BeforeClass;
 11import org.junit.Test;
 12import org.junit.runner.RunWith;
 13import org.robolectric.RobolectricTestRunner;
 14import org.robolectric.annotation.Config;
 15import org.robolectric.annotation.ConscryptMode;
 16
 17import android.os.Build;
 18import eu.siacs.conversations.Conversations;
 19import eu.siacs.conversations.xmpp.Jid;
 20import junit.framework.Assert;
 21import net.bytebuddy.ByteBuddy;
 22import net.bytebuddy.asm.AsmVisitorWrapper;
 23import net.bytebuddy.description.method.MethodDescription;
 24import net.bytebuddy.description.type.TypeDescription;
 25import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
 26import net.bytebuddy.implementation.Implementation;
 27import net.bytebuddy.jar.asm.MethodVisitor;
 28import net.bytebuddy.jar.asm.Opcodes;
 29import net.bytebuddy.jar.asm.Type;
 30import net.bytebuddy.pool.TypePool;
 31
 32@RunWith(RobolectricTestRunner.class)
 33@Config(sdk = Build.VERSION_CODES.TIRAMISU, application = Conversations.class)
 34@ConscryptMode(ConscryptMode.Mode.OFF)
 35public class ConversationGetMucOptionsRaceTest {
 36
 37    static int fieldReadCount;
 38    static volatile CountDownLatch remainingReads;
 39    static volatile CountDownLatch resetDone;
 40
 41    public static void gate() {
 42        final var reads = remainingReads;
 43        final var reset = resetDone;
 44        if (reads == null || reset == null) return;
 45        final boolean lastRead = reads.getCount() == 1;
 46        reads.countDown();
 47        if (lastRead) {
 48            try {
 49                reset.await();
 50            } catch (InterruptedException e) {
 51                throw new RuntimeException(e);
 52            }
 53        }
 54    }
 55
 56    static class GetMucOptionsInstrumentor extends MethodVisitor {
 57        private int count = 0;
 58
 59        GetMucOptionsInstrumentor(MethodVisitor delegate) {
 60            super(Opcodes.ASM9, delegate);
 61        }
 62
 63        @Override
 64        public void visitFieldInsn(
 65            int opcode, String owner, String name, String descriptor
 66        ) {
 67            if (opcode == Opcodes.GETFIELD && "mucOptions".equals(name)) {
 68                count++;
 69                super.visitMethodInsn(
 70                    Opcodes.INVOKESTATIC,
 71                    Type.getInternalName(
 72                        ConversationGetMucOptionsRaceTest.class),
 73                    "gate",
 74                    "()V",
 75                    false
 76                );
 77            }
 78            super.visitFieldInsn(opcode, owner, name, descriptor);
 79        }
 80
 81        @Override
 82        public void visitEnd() {
 83            fieldReadCount = count;
 84            super.visitEnd();
 85        }
 86    }
 87
 88    @SuppressWarnings("unchecked")
 89    @BeforeClass
 90    public static void instrumentConversation() throws Exception {
 91        Class.forName("net.bytebuddy.agent.ByteBuddyAgent")
 92            .getMethod("install")
 93            .invoke(null);
 94
 95        final var strategy = (ClassLoadingStrategy<ClassLoader>)
 96            Class.forName(
 97                "net.bytebuddy.dynamic.loading.ClassReloadingStrategy")
 98                .getMethod("fromInstalledAgent")
 99                .invoke(null);
100
101        new ByteBuddy()
102            .redefine(Conversation.class)
103            .visit(new AsmVisitorWrapper.ForDeclaredMethods()
104                .method(
105                    named("getMucOptions"),
106                    new AsmVisitorWrapper.ForDeclaredMethods
107                            .MethodVisitorWrapper() {
108                        @Override
109                        public MethodVisitor wrap(
110                            TypeDescription instrumentedType,
111                            MethodDescription instrumentedMethod,
112                            MethodVisitor methodVisitor,
113                            Implementation.Context implementationContext,
114                            TypePool typePool,
115                            int writerFlags,
116                            int readerFlags
117                        ) {
118                            return new GetMucOptionsInstrumentor(
119                                methodVisitor);
120                        }
121                    }
122                )
123            )
124            .make()
125            .load(Conversation.class.getClassLoader(), strategy);
126    }
127
128    @Test
129    public void testGetMucOptionsNeverReturnsNull() throws Throwable {
130        final var account = mock(Account.class);
131        when(account.getJid()).thenReturn(
132            Jid.ofLocalAndDomain("testAccount", "example.org"));
133
134        final var conversation = new Conversation(
135            "Test MUC",
136            account,
137            Jid.ofLocalAndDomain("testMuc", "example.org"),
138            Conversation.MODE_MULTI
139        );
140        conversation.getMucOptions();
141
142        remainingReads = new CountDownLatch(fieldReadCount);
143        resetDone = new CountDownLatch(1);
144
145        final var result = new AtomicReference<MucOptions>();
146        final var error = new AtomicReference<Throwable>();
147
148        Thread reader = new Thread(() -> {
149            try {
150                result.set(conversation.getMucOptions());
151            } catch (Throwable t) {
152                error.set(t);
153            }
154        });
155
156        Thread resetter = new Thread(() -> {
157            try {
158                remainingReads.await();
159                conversation.resetMucOptions();
160                resetDone.countDown();
161            } catch (Throwable t) {
162                error.set(t);
163            }
164        });
165
166        reader.start();
167        resetter.start();
168
169        reader.join(10_000);
170        resetter.join(10_000);
171
172        remainingReads = null;
173        resetDone = null;
174
175        if (error.get() != null) throw error.get();
176
177        Assert.assertNotNull(
178            "getMucOptions() returned null"
179                + " — the field must not be re-read after the null check",
180            result.get()
181        );
182    }
183}