Ensure that futures returns from `RpcClient` are 'static

Antonio Scandurra created

Change summary

zed/src/rpc_client.rs | 123 ++++++++++++++++++++++++--------------------
1 file changed, 68 insertions(+), 55 deletions(-)

Detailed changes

zed/src/rpc_client.rs 🔗

@@ -12,6 +12,7 @@ use smol::{
 };
 use std::{
     collections::HashMap,
+    future::Future,
     sync::{
         atomic::{self, AtomicI32},
         Arc,
@@ -32,7 +33,7 @@ impl<Conn> RpcClient<Conn>
 where
     Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
 {
-    pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
+    pub fn new(conn: Conn, executor: Arc<Background>) -> Arc<Self> {
         let response_channels = Arc::new(Mutex::new(HashMap::new()));
         let (conn_rx, conn_tx) = smol::io::split(conn);
         let (_drop_tx, drop_rx) = barrier::channel();
@@ -45,12 +46,12 @@ where
             ))
             .detach();
 
-        Self {
+        Arc::new(Self {
             response_channels,
             outgoing: Mutex::new(MessageStream::new(conn_tx)),
             _drop_tx,
             next_message_id: AtomicI32::new(0),
-        }
+        })
     }
 
     async fn handle_incoming(
@@ -101,63 +102,75 @@ where
         }
     }
 
-    pub async fn request<T: RequestMessage>(&self, req: T) -> Result<T::Response> {
-        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let (tx, mut rx) = mpsc::channel(1);
-        self.response_channels
-            .lock()
-            .await
-            .insert(message_id, (tx, true));
-        self.outgoing
-            .lock()
-            .await
-            .write_message(&proto::FromClient {
-                id: message_id,
-                variant: Some(req.to_variant()),
-            })
-            .await?;
-        let response = rx
-            .recv()
-            .await
-            .expect("response channel was unexpectedly dropped");
-        T::Response::from_variant(response)
-            .ok_or_else(|| anyhow!("received response of the wrong t"))
+    pub fn request<T: RequestMessage>(
+        self: &Arc<Self>,
+        req: T,
+    ) -> impl Future<Output = Result<T::Response>> {
+        let this = self.clone();
+        async move {
+            let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+            let (tx, mut rx) = mpsc::channel(1);
+            this.response_channels
+                .lock()
+                .await
+                .insert(message_id, (tx, true));
+            this.outgoing
+                .lock()
+                .await
+                .write_message(&proto::FromClient {
+                    id: message_id,
+                    variant: Some(req.to_variant()),
+                })
+                .await?;
+            let response = rx
+                .recv()
+                .await
+                .expect("response channel was unexpectedly dropped");
+            T::Response::from_variant(response)
+                .ok_or_else(|| anyhow!("received response of the wrong t"))
+        }
     }
 
-    pub async fn send<T: SendMessage>(&self, message: T) -> Result<()> {
-        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        self.outgoing
-            .lock()
-            .await
-            .write_message(&proto::FromClient {
-                id: message_id,
-                variant: Some(message.to_variant()),
-            })
-            .await?;
-        Ok(())
+    pub fn send<T: SendMessage>(self: &Arc<Self>, message: T) -> impl Future<Output = Result<()>> {
+        let this = self.clone();
+        async move {
+            let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+            this.outgoing
+                .lock()
+                .await
+                .write_message(&proto::FromClient {
+                    id: message_id,
+                    variant: Some(message.to_variant()),
+                })
+                .await?;
+            Ok(())
+        }
     }
 
-    pub async fn subscribe<T: SubscribeMessage>(
-        &self,
+    pub fn subscribe<T: SubscribeMessage>(
+        self: &Arc<Self>,
         subscription: T,
-    ) -> Result<impl Stream<Item = Result<T::Event>>> {
-        let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
-        let (tx, rx) = mpsc::channel(256);
-        self.response_channels
-            .lock()
-            .await
-            .insert(message_id, (tx, false));
-        self.outgoing
-            .lock()
-            .await
-            .write_message(&proto::FromClient {
-                id: message_id,
-                variant: Some(subscription.to_variant()),
-            })
-            .await?;
-        Ok(rx.map(|event| {
-            T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
-        }))
+    ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Event>>>> {
+        let this = self.clone();
+        async move {
+            let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
+            let (tx, rx) = mpsc::channel(256);
+            this.response_channels
+                .lock()
+                .await
+                .insert(message_id, (tx, false));
+            this.outgoing
+                .lock()
+                .await
+                .write_message(&proto::FromClient {
+                    id: message_id,
+                    variant: Some(subscription.to_variant()),
+                })
+                .await?;
+            Ok(rx.map(|event| {
+                T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
+            }))
+        }
     }
 }