repl: Use split() on shell and control dealer sockets (#48823)

Kyle Kelley created

Hot on the heels of https://github.com/zed-industries/zed/pull/48817 I'm
bringing the best improvement to the repl underneath: `split()`-able
sockets! Much more will be unlocked by having this.

This split the shell and control `DealerSocket` connections into
independent send/recv halves using the new `split()` API from zeromq
0.5.0 and runtimelib 1.x. This also nicely cleaned things up so we could
have a single `select!` loop over iopub, shell, and control recv halves.
That replaces three separate recv tasks.

This likely closes some issues for certain kernels that would get stuck
either during startup or other flows due to them not sending replies to
specific requests. I'll see if I can find issues around this and update
the release notes after.

This allows us to unlock some nifty new things we can do on the shell
socket, particularly autocompletion for in-memory values, stdin support,
and others. I _think_ it also help with sending and receving
`KernelInfo`, which not all kernels do properly at the start. This makes
us a bit more resilient to errant kernels.

Release Notes:

- N/A

Change summary

crates/repl/src/kernels/mod.rs           |   8 -
crates/repl/src/kernels/native_kernel.rs | 122 ++++++++-----------------
2 files changed, 42 insertions(+), 88 deletions(-)

Detailed changes

crates/repl/src/kernels/mod.rs 🔗

@@ -1,11 +1,7 @@
 mod native_kernel;
 use std::{fmt::Debug, future::Future, path::PathBuf};
 
-use futures::{
-    channel::mpsc::{self, Receiver},
-    future::Shared,
-    stream,
-};
+use futures::{channel::mpsc, future::Shared};
 use gpui::{App, Entity, Task, Window};
 use language::LanguageName;
 pub use native_kernel::*;
@@ -26,8 +22,6 @@ pub trait KernelSession: Sized {
     fn kernel_errored(&mut self, error_message: String, cx: &mut Context<Self>);
 }
 
-pub type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
-
 #[derive(Debug, Clone, PartialEq, Eq)]
 pub enum KernelSpecification {
     Remote(RemoteKernelSpecification),

crates/repl/src/kernels/native_kernel.rs 🔗

@@ -1,9 +1,9 @@
 use anyhow::{Context as _, Result};
 use futures::{
-    AsyncBufReadExt as _, SinkExt as _,
+    AsyncBufReadExt as _, FutureExt as _, StreamExt as _,
     channel::mpsc::{self},
     io::BufReader,
-    stream::{FuturesUnordered, SelectAll, StreamExt},
+    stream::FuturesUnordered,
 };
 use gpui::{App, AppContext as _, Entity, EntityId, Task, Window};
 use jupyter_protocol::{
@@ -151,46 +151,33 @@ impl NativeRunningKernel {
 
             let session_id = Uuid::new_v4().to_string();
 
-            let mut iopub_socket =
+            let iopub_socket =
                 runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
                     .await?;
-            let mut shell_socket =
+            let shell_socket =
                 runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
-            let mut control_socket =
+            let control_socket =
                 runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
 
+            let (mut shell_send, shell_recv) = shell_socket.split();
+            let (mut control_send, control_recv) = control_socket.split();
+
             let (request_tx, mut request_rx) =
                 futures::channel::mpsc::channel::<JupyterMessage>(100);
 
-            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
-            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
-
-            let mut messages_rx = SelectAll::new();
-            messages_rx.push(control_reply_rx);
-            messages_rx.push(shell_reply_rx);
-
-            cx.spawn({
-                let session = session.clone();
-
-                async move |cx| {
-                    while let Some(message) = messages_rx.next().await {
-                        session
-                            .update_in(cx, |session, window, cx| {
-                                session.route(&message, window, cx);
-                            })
-                            .ok();
-                    }
-                }
-            })
-            .detach();
-
-            // iopub task
-            let iopub_task = cx.spawn({
+            let recv_task = cx.spawn({
                 let session = session.clone();
+                let mut iopub = iopub_socket;
+                let mut shell = shell_recv;
+                let mut control = control_recv;
 
                 async move |cx| -> anyhow::Result<()> {
                     loop {
-                        let message = iopub_socket.read().await?;
+                        let message = futures::select! {
+                            msg = iopub.read().fuse() => msg.context("iopub recv")?,
+                            msg = shell.read().fuse() => msg.context("shell recv")?,
+                            msg = control.read().fuse() => msg.context("control recv")?,
+                        };
                         session
                             .update_in(cx, |session, window, cx| {
                                 session.route(&message, window, cx);
@@ -200,10 +187,6 @@ impl NativeRunningKernel {
                 }
             });
 
-            let (mut control_request_tx, mut control_request_rx) =
-                futures::channel::mpsc::channel(100);
-            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
-
             let routing_task = cx.background_spawn({
                 async move {
                     while let Some(message) = request_rx.next().await {
@@ -211,10 +194,10 @@ impl NativeRunningKernel {
                             JupyterMessageContent::DebugRequest(_)
                             | JupyterMessageContent::InterruptRequest(_)
                             | JupyterMessageContent::ShutdownRequest(_) => {
-                                control_request_tx.send(message).await?;
+                                control_send.send(message).await?;
                             }
                             _ => {
-                                shell_request_tx.send(message).await?;
+                                shell_send.send(message).await?;
                             }
                         }
                     }
@@ -222,52 +205,31 @@ impl NativeRunningKernel {
                 }
             });
 
-            let shell_task = cx.background_spawn({
-                async move {
-                    while let Some(message) = shell_request_rx.next().await {
-                        shell_socket.send(message).await.ok();
-                        let reply = shell_socket.read().await?;
-                        shell_reply_tx.send(reply).await?;
-                    }
-                    anyhow::Ok(())
-                }
-            });
-
-            let control_task = cx.background_spawn({
-                async move {
-                    while let Some(message) = control_request_rx.next().await {
-                        control_socket.send(message).await.ok();
-                        let reply = control_socket.read().await?;
-                        control_reply_tx.send(reply).await?;
-                    }
-                    anyhow::Ok(())
-                }
-            });
-
             let stderr = process.stderr.take();
-
-            cx.spawn(async move |_cx| {
-                if stderr.is_none() {
-                    return;
-                }
-                let reader = BufReader::new(stderr.unwrap());
-                let mut lines = reader.lines();
-                while let Some(Ok(line)) = lines.next().await {
-                    log::error!("kernel: {}", line);
-                }
-            })
-            .detach();
-
             let stdout = process.stdout.take();
 
             cx.spawn(async move |_cx| {
-                if stdout.is_none() {
-                    return;
-                }
-                let reader = BufReader::new(stdout.unwrap());
-                let mut lines = reader.lines();
-                while let Some(Ok(line)) = lines.next().await {
-                    log::info!("kernel: {}", line);
+                use futures::future::Either;
+
+                let stderr_lines = match stderr {
+                    Some(s) => Either::Left(
+                        BufReader::new(s)
+                            .lines()
+                            .map(|line| (log::Level::Error, line)),
+                    ),
+                    None => Either::Right(futures::stream::empty()),
+                };
+                let stdout_lines = match stdout {
+                    Some(s) => Either::Left(
+                        BufReader::new(s)
+                            .lines()
+                            .map(|line| (log::Level::Info, line)),
+                    ),
+                    None => Either::Right(futures::stream::empty()),
+                };
+                let mut lines = futures::stream::select(stderr_lines, stdout_lines);
+                while let Some((level, Ok(line))) = lines.next().await {
+                    log::log!(level, "kernel: {}", line);
                 }
             })
             .detach();
@@ -283,9 +245,7 @@ impl NativeRunningKernel {
                     }
 
                     let mut tasks = FuturesUnordered::new();
-                    tasks.push(with_name("iopub task", iopub_task));
-                    tasks.push(with_name("shell task", shell_task));
-                    tasks.push(with_name("control task", control_task));
+                    tasks.push(with_name("recv task", recv_task));
                     tasks.push(with_name("routing task", routing_task));
 
                     while let Some((name, result)) = tasks.next().await {