Avoid possible memory leak of FakeServer in tests

Max Brunsfeld created

Change summary

crates/client/src/test.rs | 20 ++++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)

Detailed changes

crates/client/src/test.rs 🔗

@@ -41,12 +41,14 @@ impl FakeServer {
         Arc::get_mut(client)
             .unwrap()
             .override_authenticate({
-                let state = server.state.clone();
+                let state = Arc::downgrade(&server.state);
                 move |cx| {
-                    let mut state = state.lock();
-                    state.auth_count += 1;
-                    let access_token = state.access_token.to_string();
+                    let state = state.clone();
                     cx.spawn(move |_| async move {
+                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
+                        let mut state = state.lock();
+                        state.auth_count += 1;
+                        let access_token = state.access_token.to_string();
                         Ok(Credentials {
                             user_id: client_user_id,
                             access_token,
@@ -55,21 +57,23 @@ impl FakeServer {
                 }
             })
             .override_establish_connection({
-                let peer = server.peer.clone();
-                let state = server.state.clone();
+                let peer = Arc::downgrade(&server.peer).clone();
+                let state = Arc::downgrade(&server.state);
                 move |credentials, cx| {
                     let peer = peer.clone();
                     let state = state.clone();
                     let credentials = credentials.clone();
                     cx.spawn(move |cx| async move {
-                        assert_eq!(credentials.user_id, client_user_id);
-
+                        let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
+                        let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
                         if state.lock().forbid_connections {
                             Err(EstablishConnectionError::Other(anyhow!(
                                 "server is forbidding connections"
                             )))?
                         }
 
+                        assert_eq!(credentials.user_id, client_user_id);
+
                         if credentials.access_token != state.lock().access_token.to_string() {
                             Err(EstablishConnectionError::Unauthorized)?
                         }