1package eu.siacs.conversations.entities;
2
3import static net.bytebuddy.matcher.ElementMatchers.named;
4import static org.mockito.Mockito.mock;
5import static org.mockito.Mockito.when;
6
7import java.util.concurrent.CountDownLatch;
8import java.util.concurrent.atomic.AtomicReference;
9
10import org.junit.BeforeClass;
11import org.junit.Test;
12import org.junit.runner.RunWith;
13import org.robolectric.RobolectricTestRunner;
14import org.robolectric.annotation.Config;
15import org.robolectric.annotation.ConscryptMode;
16
17import android.os.Build;
18import eu.siacs.conversations.Conversations;
19import eu.siacs.conversations.xmpp.Jid;
20import junit.framework.Assert;
21import net.bytebuddy.ByteBuddy;
22import net.bytebuddy.asm.AsmVisitorWrapper;
23import net.bytebuddy.description.method.MethodDescription;
24import net.bytebuddy.description.type.TypeDescription;
25import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
26import net.bytebuddy.implementation.Implementation;
27import net.bytebuddy.jar.asm.MethodVisitor;
28import net.bytebuddy.jar.asm.Opcodes;
29import net.bytebuddy.jar.asm.Type;
30import net.bytebuddy.pool.TypePool;
31
32@RunWith(RobolectricTestRunner.class)
33@Config(sdk = Build.VERSION_CODES.TIRAMISU, application = Conversations.class)
34@ConscryptMode(ConscryptMode.Mode.OFF)
35public class ConversationGetMucOptionsRaceTest {
36
37 static int fieldReadCount;
38 static volatile CountDownLatch remainingReads;
39 static volatile CountDownLatch resetDone;
40
41 public static void gate() {
42 final var reads = remainingReads;
43 final var reset = resetDone;
44 if (reads == null || reset == null) return;
45 final boolean lastRead = reads.getCount() == 1;
46 reads.countDown();
47 if (lastRead) {
48 try {
49 reset.await();
50 } catch (InterruptedException e) {
51 throw new RuntimeException(e);
52 }
53 }
54 }
55
56 static class GetMucOptionsInstrumentor extends MethodVisitor {
57 private int count = 0;
58
59 GetMucOptionsInstrumentor(MethodVisitor delegate) {
60 super(Opcodes.ASM9, delegate);
61 }
62
63 @Override
64 public void visitFieldInsn(
65 int opcode, String owner, String name, String descriptor
66 ) {
67 if (opcode == Opcodes.GETFIELD && "mucOptions".equals(name)) {
68 count++;
69 super.visitMethodInsn(
70 Opcodes.INVOKESTATIC,
71 Type.getInternalName(
72 ConversationGetMucOptionsRaceTest.class),
73 "gate",
74 "()V",
75 false
76 );
77 }
78 super.visitFieldInsn(opcode, owner, name, descriptor);
79 }
80
81 @Override
82 public void visitEnd() {
83 fieldReadCount = count;
84 super.visitEnd();
85 }
86 }
87
88 @SuppressWarnings("unchecked")
89 @BeforeClass
90 public static void instrumentConversation() throws Exception {
91 Class.forName("net.bytebuddy.agent.ByteBuddyAgent")
92 .getMethod("install")
93 .invoke(null);
94
95 final var strategy = (ClassLoadingStrategy<ClassLoader>)
96 Class.forName(
97 "net.bytebuddy.dynamic.loading.ClassReloadingStrategy")
98 .getMethod("fromInstalledAgent")
99 .invoke(null);
100
101 new ByteBuddy()
102 .redefine(Conversation.class)
103 .visit(new AsmVisitorWrapper.ForDeclaredMethods()
104 .method(
105 named("getMucOptions"),
106 new AsmVisitorWrapper.ForDeclaredMethods
107 .MethodVisitorWrapper() {
108 @Override
109 public MethodVisitor wrap(
110 TypeDescription instrumentedType,
111 MethodDescription instrumentedMethod,
112 MethodVisitor methodVisitor,
113 Implementation.Context implementationContext,
114 TypePool typePool,
115 int writerFlags,
116 int readerFlags
117 ) {
118 return new GetMucOptionsInstrumentor(
119 methodVisitor);
120 }
121 }
122 )
123 )
124 .make()
125 .load(Conversation.class.getClassLoader(), strategy);
126 }
127
128 @Test
129 public void testGetMucOptionsNeverReturnsNull() throws Throwable {
130 final var account = mock(Account.class);
131 when(account.getJid()).thenReturn(
132 Jid.ofLocalAndDomain("testAccount", "example.org"));
133
134 final var conversation = new Conversation(
135 "Test MUC",
136 account,
137 Jid.ofLocalAndDomain("testMuc", "example.org"),
138 Conversation.MODE_MULTI
139 );
140 conversation.getMucOptions();
141
142 remainingReads = new CountDownLatch(fieldReadCount);
143 resetDone = new CountDownLatch(1);
144
145 final var result = new AtomicReference<MucOptions>();
146 final var error = new AtomicReference<Throwable>();
147
148 Thread reader = new Thread(() -> {
149 try {
150 result.set(conversation.getMucOptions());
151 } catch (Throwable t) {
152 error.set(t);
153 }
154 });
155
156 Thread resetter = new Thread(() -> {
157 try {
158 remainingReads.await();
159 conversation.resetMucOptions();
160 resetDone.countDown();
161 } catch (Throwable t) {
162 error.set(t);
163 }
164 });
165
166 reader.start();
167 resetter.start();
168
169 reader.join(10_000);
170 resetter.join(10_000);
171
172 remainingReads = null;
173 resetDone = null;
174
175 if (error.get() != null) throw error.get();
176
177 Assert.assertNotNull(
178 "getMucOptions() returned null"
179 + " — the field must not be re-read after the null check",
180 result.get()
181 );
182 }
183}