Clean up tasks properly when dropping a FakeLanguageServer

Max Brunsfeld created

* Make sure the fake's IO tasks are stopped
* Ensure that the fake's stdout is closed, so that the corresponding language
  server's IO tasks are woken up and halted.

Change summary

crates/lsp/src/lsp.rs | 166 +++++++++++++++++++++++++++-----------------
1 file changed, 102 insertions(+), 64 deletions(-)

Detailed changes

crates/lsp/src/lsp.rs 🔗

@@ -476,18 +476,22 @@ impl Drop for Subscription {
 
 #[cfg(any(test, feature = "test-support"))]
 pub struct FakeLanguageServer {
-    handlers: Arc<
-        Mutex<
-            HashMap<
-                &'static str,
-                Box<dyn Send + FnMut(usize, &[u8], gpui::AsyncAppContext) -> Vec<u8>>,
-            >,
-        >,
-    >,
+    handlers: FakeLanguageServerHandlers,
     outgoing_tx: futures::channel::mpsc::UnboundedSender<Vec<u8>>,
     incoming_rx: futures::channel::mpsc::UnboundedReceiver<Vec<u8>>,
+    _input_task: Task<Result<()>>,
+    _output_task: Task<Result<()>>,
 }
 
+type FakeLanguageServerHandlers = Arc<
+    Mutex<
+        HashMap<
+            &'static str,
+            Box<dyn Send + FnMut(usize, &[u8], gpui::AsyncAppContext) -> Vec<u8>>,
+        >,
+    >,
+>;
+
 #[cfg(any(test, feature = "test-support"))]
 impl LanguageServer {
     pub fn fake(cx: &mut gpui::MutableAppContext) -> (Arc<Self>, FakeLanguageServer) {
@@ -533,59 +537,69 @@ impl FakeLanguageServer {
 
         let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded();
         let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
-        let this = Self {
-            outgoing_tx: outgoing_tx.clone(),
-            incoming_rx,
-            handlers: Default::default(),
-        };
+        let handlers = FakeLanguageServerHandlers::default();
 
-        // Receive incoming messages
-        let handlers = this.handlers.clone();
-        cx.spawn(|cx| async move {
-            let mut buffer = Vec::new();
-            let mut stdin = smol::io::BufReader::new(stdin);
-            while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
-                cx.background().simulate_random_delay().await;
-                if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
-                    assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
-
-                    if let Some(handler) = handlers.lock().get_mut(request.method) {
-                        let response =
-                            handler(request.id, request.params.get().as_bytes(), cx.clone());
-                        log::debug!("handled lsp request. method:{}", request.method);
-                        outgoing_tx.unbounded_send(response)?;
-                    } else {
-                        log::debug!("unhandled lsp request. method:{}", request.method);
-                        outgoing_tx.unbounded_send(
-                            serde_json::to_vec(&AnyResponse {
+        let input_task = cx.spawn(|cx| {
+            let handlers = handlers.clone();
+            let outgoing_tx = outgoing_tx.clone();
+            async move {
+                let mut buffer = Vec::new();
+                let mut stdin = smol::io::BufReader::new(stdin);
+                while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
+                    cx.background().simulate_random_delay().await;
+                    if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
+                        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
+
+                        let response;
+                        if let Some(handler) = handlers.lock().get_mut(request.method) {
+                            response =
+                                handler(request.id, request.params.get().as_bytes(), cx.clone());
+                            log::debug!("handled lsp request. method:{}", request.method);
+                        } else {
+                            response = serde_json::to_vec(&AnyResponse {
                                 id: request.id,
                                 error: Some(Error {
                                     message: "no handler".to_string(),
                                 }),
                                 result: None,
                             })
-                            .unwrap(),
-                        )?;
+                            .unwrap();
+                            log::debug!("unhandled lsp request. method:{}", request.method);
+                        }
+                        outgoing_tx.unbounded_send(response)?;
+                    } else {
+                        incoming_tx.unbounded_send(buffer.clone())?;
                     }
-                } else {
-                    incoming_tx.unbounded_send(buffer.clone())?;
                 }
+                Ok::<_, anyhow::Error>(())
             }
-            Ok::<_, anyhow::Error>(())
-        })
-        .detach();
-
-        // Send outgoing messages
-        cx.background()
-            .spawn(async move {
-                let mut stdout = smol::io::BufWriter::new(stdout);
-                while let Some(notification) = outgoing_rx.next().await {
-                    Self::send(&mut stdout, &notification).await;
-                }
-            })
-            .detach();
+        });
 
-        this
+        let output_task = cx.background().spawn(async move {
+            let mut stdout = smol::io::BufWriter::new(PipeWriterCloseOnDrop(stdout));
+            while let Some(message) = outgoing_rx.next().await {
+                stdout
+                    .write_all(CONTENT_LEN_HEADER.as_bytes())
+                    .await
+                    .unwrap();
+                stdout
+                    .write_all((format!("{}", message.len())).as_bytes())
+                    .await
+                    .unwrap();
+                stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
+                stdout.write_all(&message).await.unwrap();
+                stdout.flush().await.unwrap();
+            }
+            Ok(())
+        });
+
+        Self {
+            outgoing_tx,
+            incoming_rx,
+            handlers,
+            _input_task: input_task,
+            _output_task: output_task,
+        }
     }
 
     pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
@@ -665,20 +679,6 @@ impl FakeLanguageServer {
         .await;
     }
 
-    async fn send(stdout: &mut smol::io::BufWriter<async_pipe::PipeWriter>, message: &[u8]) {
-        stdout
-            .write_all(CONTENT_LEN_HEADER.as_bytes())
-            .await
-            .unwrap();
-        stdout
-            .write_all((format!("{}", message.len())).as_bytes())
-            .await
-            .unwrap();
-        stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
-        stdout.write_all(&message).await.unwrap();
-        stdout.flush().await.unwrap();
-    }
-
     async fn receive(
         stdin: &mut smol::io::BufReader<async_pipe::PipeReader>,
         buffer: &mut Vec<u8>,
@@ -699,6 +699,44 @@ impl FakeLanguageServer {
     }
 }
 
+struct PipeWriterCloseOnDrop(async_pipe::PipeWriter);
+
+impl Drop for PipeWriterCloseOnDrop {
+    fn drop(&mut self) {
+        self.0.close().ok();
+    }
+}
+
+impl AsyncWrite for PipeWriterCloseOnDrop {
+    fn poll_write(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &[u8],
+    ) -> std::task::Poll<std::io::Result<usize>> {
+        let pipe = &mut self.0;
+        smol::pin!(pipe);
+        pipe.poll_write(cx, buf)
+    }
+
+    fn poll_flush(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        let pipe = &mut self.0;
+        smol::pin!(pipe);
+        pipe.poll_flush(cx)
+    }
+
+    fn poll_close(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<std::io::Result<()>> {
+        let pipe = &mut self.0;
+        smol::pin!(pipe);
+        pipe.poll_close(cx)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;