Introduce a new `TryFutureExt::unwrap` method

Antonio Scandurra created

Change summary

Cargo.lock                             |   1 
crates/client/src/telemetry.rs         |   4 
crates/collab_ui/src/contact_finder.rs |   2 
crates/journal/Cargo.toml              |   1 
crates/journal/src/journal.rs          |   2 
crates/lsp/src/lsp.rs                  | 191 +++++++++++++++------------
crates/project/src/project.rs          |   2 
crates/util/src/util.rs                |  40 +++++
8 files changed, 151 insertions(+), 92 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3157,6 +3157,7 @@ dependencies = [
 name = "journal"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "chrono",
  "dirs 4.0.0",
  "editor",

crates/client/src/telemetry.rs 🔗

@@ -224,7 +224,7 @@ impl Telemetry {
                             .header("Content-Type", "application/json")
                             .body(json_bytes.into())?;
                         this.http_client.send(request).await?;
-                        Ok(())
+                        anyhow::Ok(())
                     }
                     .log_err(),
                 )
@@ -320,7 +320,7 @@ impl Telemetry {
                             .header("Content-Type", "application/json")
                             .body(json_bytes.into())?;
                         this.http_client.send(request).await?;
-                        Ok(())
+                        anyhow::Ok(())
                     }
                     .log_err(),
                 )

crates/collab_ui/src/contact_finder.rs 🔗

@@ -68,7 +68,7 @@ impl PickerDelegate for ContactFinder {
                     this.potential_contacts = potential_contacts.into();
                     cx.notify();
                 });
-                Ok(())
+                anyhow::Ok(())
             }
             .log_err()
             .await;

crates/journal/Cargo.toml 🔗

@@ -13,6 +13,7 @@ editor = { path = "../editor" }
 gpui = { path = "../gpui" }
 util = { path = "../util" }
 workspace = { path = "../workspace" }
+anyhow = "1.0"
 chrono = "0.4"
 dirs = "4.0"
 log = { version = "0.4.16", features = ["kv_unstable_serde"] }

crates/journal/src/journal.rs 🔗

@@ -73,7 +73,7 @@ pub fn new_journal_entry(app_state: Arc<AppState>, cx: &mut MutableAppContext) {
                 }
             }
 
-            Ok(())
+            anyhow::Ok(())
         }
         .log_err()
     })

crates/lsp/src/lsp.rs 🔗

@@ -160,15 +160,13 @@ impl LanguageServer {
         server: Option<Child>,
         root_path: &Path,
         cx: AsyncAppContext,
-        mut on_unhandled_notification: F,
+        on_unhandled_notification: F,
     ) -> Self
     where
         Stdin: AsyncWrite + Unpin + Send + 'static,
         Stdout: AsyncRead + Unpin + Send + 'static,
         F: FnMut(AnyNotification) + 'static + Send,
     {
-        let mut stdin = BufWriter::new(stdin);
-        let mut stdout = BufReader::new(stdout);
         let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
         let notification_handlers =
             Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
@@ -177,89 +175,19 @@ impl LanguageServer {
         let input_task = cx.spawn(|cx| {
             let notification_handlers = notification_handlers.clone();
             let response_handlers = response_handlers.clone();
-            async move {
-                let _clear_response_handlers = util::defer({
-                    let response_handlers = response_handlers.clone();
-                    move || {
-                        response_handlers.lock().take();
-                    }
-                });
-                let mut buffer = Vec::new();
-                loop {
-                    buffer.clear();
-                    stdout.read_until(b'\n', &mut buffer).await?;
-                    stdout.read_until(b'\n', &mut buffer).await?;
-                    let message_len: usize = std::str::from_utf8(&buffer)?
-                        .strip_prefix(CONTENT_LEN_HEADER)
-                        .ok_or_else(|| anyhow!("invalid header"))?
-                        .trim_end()
-                        .parse()?;
-
-                    buffer.resize(message_len, 0);
-                    stdout.read_exact(&mut buffer).await?;
-                    log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
-
-                    if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
-                        if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
-                            handler(msg.id, msg.params.get(), cx.clone());
-                        } else {
-                            on_unhandled_notification(msg);
-                        }
-                    } else if let Ok(AnyResponse {
-                        id, error, result, ..
-                    }) = serde_json::from_slice(&buffer)
-                    {
-                        if let Some(handler) = response_handlers
-                            .lock()
-                            .as_mut()
-                            .and_then(|handlers| handlers.remove(&id))
-                        {
-                            if let Some(error) = error {
-                                handler(Err(error));
-                            } else if let Some(result) = result {
-                                handler(Ok(result.get()));
-                            } else {
-                                handler(Ok("null"));
-                            }
-                        }
-                    } else {
-                        warn!(
-                            "Failed to deserialize message:\n{}",
-                            std::str::from_utf8(&buffer)?
-                        );
-                    }
-
-                    // Don't starve the main thread when receiving lots of messages at once.
-                    smol::future::yield_now().await;
-                }
-            }
+            Self::handle_input(
+                stdout,
+                on_unhandled_notification,
+                notification_handlers,
+                response_handlers,
+                cx,
+            )
             .log_err()
         });
         let (output_done_tx, output_done_rx) = barrier::channel();
         let output_task = cx.background().spawn({
             let response_handlers = response_handlers.clone();
-            async move {
-                let _clear_response_handlers = util::defer({
-                    let response_handlers = response_handlers.clone();
-                    move || {
-                        response_handlers.lock().take();
-                    }
-                });
-                let mut content_len_buffer = Vec::new();
-                while let Ok(message) = outbound_rx.recv().await {
-                    log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
-                    content_len_buffer.clear();
-                    write!(content_len_buffer, "{}", message.len()).unwrap();
-                    stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
-                    stdin.write_all(&content_len_buffer).await?;
-                    stdin.write_all("\r\n\r\n".as_bytes()).await?;
-                    stdin.write_all(&message).await?;
-                    stdin.flush().await?;
-                }
-                drop(output_done_tx);
-                Ok(())
-            }
-            .log_err()
+            Self::handle_output(stdin, outbound_rx, output_done_tx, response_handlers).log_err()
         });
 
         Self {
@@ -278,6 +206,105 @@ impl LanguageServer {
         }
     }
 
+    async fn handle_input<Stdout, F>(
+        stdout: Stdout,
+        mut on_unhandled_notification: F,
+        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+        cx: AsyncAppContext,
+    ) -> anyhow::Result<()>
+    where
+        Stdout: AsyncRead + Unpin + Send + 'static,
+        F: FnMut(AnyNotification) + 'static + Send,
+    {
+        let mut stdout = BufReader::new(stdout);
+        let _clear_response_handlers = util::defer({
+            let response_handlers = response_handlers.clone();
+            move || {
+                response_handlers.lock().take();
+            }
+        });
+        let mut buffer = Vec::new();
+        loop {
+            buffer.clear();
+            stdout.read_until(b'\n', &mut buffer).await?;
+            stdout.read_until(b'\n', &mut buffer).await?;
+            let message_len: usize = std::str::from_utf8(&buffer)?
+                .strip_prefix(CONTENT_LEN_HEADER)
+                .ok_or_else(|| anyhow!("invalid header"))?
+                .trim_end()
+                .parse()?;
+
+            buffer.resize(message_len, 0);
+            stdout.read_exact(&mut buffer).await?;
+            log::trace!("incoming message:{}", String::from_utf8_lossy(&buffer));
+
+            if let Ok(msg) = serde_json::from_slice::<AnyNotification>(&buffer) {
+                if let Some(handler) = notification_handlers.lock().get_mut(msg.method) {
+                    handler(msg.id, msg.params.get(), cx.clone());
+                } else {
+                    on_unhandled_notification(msg);
+                }
+            } else if let Ok(AnyResponse {
+                id, error, result, ..
+            }) = serde_json::from_slice(&buffer)
+            {
+                if let Some(handler) = response_handlers
+                    .lock()
+                    .as_mut()
+                    .and_then(|handlers| handlers.remove(&id))
+                {
+                    if let Some(error) = error {
+                        handler(Err(error));
+                    } else if let Some(result) = result {
+                        handler(Ok(result.get()));
+                    } else {
+                        handler(Ok("null"));
+                    }
+                }
+            } else {
+                warn!(
+                    "Failed to deserialize message:\n{}",
+                    std::str::from_utf8(&buffer)?
+                );
+            }
+
+            // Don't starve the main thread when receiving lots of messages at once.
+            smol::future::yield_now().await;
+        }
+    }
+
+    async fn handle_output<Stdin>(
+        stdin: Stdin,
+        outbound_rx: channel::Receiver<Vec<u8>>,
+        output_done_tx: barrier::Sender,
+        response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
+    ) -> anyhow::Result<()>
+    where
+        Stdin: AsyncWrite + Unpin + Send + 'static,
+    {
+        let mut stdin = BufWriter::new(stdin);
+        let _clear_response_handlers = util::defer({
+            let response_handlers = response_handlers.clone();
+            move || {
+                response_handlers.lock().take();
+            }
+        });
+        let mut content_len_buffer = Vec::new();
+        while let Ok(message) = outbound_rx.recv().await {
+            log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
+            content_len_buffer.clear();
+            write!(content_len_buffer, "{}", message.len()).unwrap();
+            stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
+            stdin.write_all(&content_len_buffer).await?;
+            stdin.write_all("\r\n\r\n".as_bytes()).await?;
+            stdin.write_all(&message).await?;
+            stdin.flush().await?;
+        }
+        drop(output_done_tx);
+        Ok(())
+    }
+
     /// Initializes a language server.
     /// Note that `options` is used directly to construct [`InitializeParams`],
     /// which is why it is owned.
@@ -389,7 +416,7 @@ impl LanguageServer {
                     output_done.recv().await;
                     log::debug!("language server shutdown finished");
                     drop(tasks);
-                    Ok(())
+                    anyhow::Ok(())
                 }
                 .log_err(),
             )

crates/util/src/util.rs 🔗

@@ -124,11 +124,15 @@ pub trait TryFutureExt {
     fn warn_on_err(self) -> LogErrorFuture<Self>
     where
         Self: Sized;
+    fn unwrap(self) -> UnwrapFuture<Self>
+    where
+        Self: Sized;
 }
 
-impl<F, T> TryFutureExt for F
+impl<F, T, E> TryFutureExt for F
 where
-    F: Future<Output = anyhow::Result<T>>,
+    F: Future<Output = Result<T, E>>,
+    E: std::fmt::Debug,
 {
     fn log_err(self) -> LogErrorFuture<Self>
     where
@@ -143,17 +147,25 @@ where
     {
         LogErrorFuture(self, log::Level::Warn)
     }
+
+    fn unwrap(self) -> UnwrapFuture<Self>
+    where
+        Self: Sized,
+    {
+        UnwrapFuture(self)
+    }
 }
 
 pub struct LogErrorFuture<F>(F, log::Level);
 
-impl<F, T> Future for LogErrorFuture<F>
+impl<F, T, E> Future for LogErrorFuture<F>
 where
-    F: Future<Output = anyhow::Result<T>>,
+    F: Future<Output = Result<T, E>>,
+    E: std::fmt::Debug,
 {
     type Output = Option<T>;
 
-    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
         let level = self.1;
         let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
         match inner.poll(cx) {
@@ -169,6 +181,24 @@ where
     }
 }
 
+pub struct UnwrapFuture<F>(F);
+
+impl<F, T, E> Future for UnwrapFuture<F>
+where
+    F: Future<Output = Result<T, E>>,
+    E: std::fmt::Debug,
+{
+    type Output = T;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+        let inner = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0) };
+        match inner.poll(cx) {
+            Poll::Ready(result) => Poll::Ready(result.unwrap()),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
 struct Defer<F: FnOnce()>(Option<F>);
 
 impl<F: FnOnce()> Drop for Defer<F> {