diff --git a/crates/repl/src/kernels/mod.rs b/crates/repl/src/kernels/mod.rs index 87ed6ed398fdd4f5ef23d7fef04fca9e000f1e73..aaaaa40765e1fa4c8abc3cab4394f4f91d77a6b6 100644 --- a/crates/repl/src/kernels/mod.rs +++ b/crates/repl/src/kernels/mod.rs @@ -1,11 +1,7 @@ mod native_kernel; use std::{fmt::Debug, future::Future, path::PathBuf}; -use futures::{ - channel::mpsc::{self, Receiver}, - future::Shared, - stream, -}; +use futures::{channel::mpsc, future::Shared}; use gpui::{App, Entity, Task, Window}; use language::LanguageName; pub use native_kernel::*; @@ -26,8 +22,6 @@ pub trait KernelSession: Sized { fn kernel_errored(&mut self, error_message: String, cx: &mut Context); } -pub type JupyterMessageChannel = stream::SelectAll>; - #[derive(Debug, Clone, PartialEq, Eq)] pub enum KernelSpecification { Remote(RemoteKernelSpecification), diff --git a/crates/repl/src/kernels/native_kernel.rs b/crates/repl/src/kernels/native_kernel.rs index c3fd57557a2c77c8e1c490aae6e7533cd7d161b6..572626d5323ecca6b3804c692c53eed81b599ae8 100644 --- a/crates/repl/src/kernels/native_kernel.rs +++ b/crates/repl/src/kernels/native_kernel.rs @@ -1,9 +1,9 @@ use anyhow::{Context as _, Result}; use futures::{ - AsyncBufReadExt as _, SinkExt as _, + AsyncBufReadExt as _, FutureExt as _, StreamExt as _, channel::mpsc::{self}, io::BufReader, - stream::{FuturesUnordered, SelectAll, StreamExt}, + stream::FuturesUnordered, }; use gpui::{App, AppContext as _, Entity, EntityId, Task, Window}; use jupyter_protocol::{ @@ -151,46 +151,33 @@ impl NativeRunningKernel { let session_id = Uuid::new_v4().to_string(); - let mut iopub_socket = + let iopub_socket = runtimelib::create_client_iopub_connection(&connection_info, "", &session_id) .await?; - let mut shell_socket = + let shell_socket = runtimelib::create_client_shell_connection(&connection_info, &session_id).await?; - let mut control_socket = + let control_socket = runtimelib::create_client_control_connection(&connection_info, &session_id).await?; + let (mut shell_send, shell_recv) = shell_socket.split(); + let (mut control_send, control_recv) = control_socket.split(); + let (request_tx, mut request_rx) = futures::channel::mpsc::channel::(100); - let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100); - let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100); - - let mut messages_rx = SelectAll::new(); - messages_rx.push(control_reply_rx); - messages_rx.push(shell_reply_rx); - - cx.spawn({ - let session = session.clone(); - - async move |cx| { - while let Some(message) = messages_rx.next().await { - session - .update_in(cx, |session, window, cx| { - session.route(&message, window, cx); - }) - .ok(); - } - } - }) - .detach(); - - // iopub task - let iopub_task = cx.spawn({ + let recv_task = cx.spawn({ let session = session.clone(); + let mut iopub = iopub_socket; + let mut shell = shell_recv; + let mut control = control_recv; async move |cx| -> anyhow::Result<()> { loop { - let message = iopub_socket.read().await?; + let message = futures::select! { + msg = iopub.read().fuse() => msg.context("iopub recv")?, + msg = shell.read().fuse() => msg.context("shell recv")?, + msg = control.read().fuse() => msg.context("control recv")?, + }; session .update_in(cx, |session, window, cx| { session.route(&message, window, cx); @@ -200,10 +187,6 @@ impl NativeRunningKernel { } }); - let (mut control_request_tx, mut control_request_rx) = - futures::channel::mpsc::channel(100); - let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100); - let routing_task = cx.background_spawn({ async move { while let Some(message) = request_rx.next().await { @@ -211,10 +194,10 @@ impl NativeRunningKernel { JupyterMessageContent::DebugRequest(_) | JupyterMessageContent::InterruptRequest(_) | JupyterMessageContent::ShutdownRequest(_) => { - control_request_tx.send(message).await?; + control_send.send(message).await?; } _ => { - shell_request_tx.send(message).await?; + shell_send.send(message).await?; } } } @@ -222,52 +205,31 @@ impl NativeRunningKernel { } }); - let shell_task = cx.background_spawn({ - async move { - while let Some(message) = shell_request_rx.next().await { - shell_socket.send(message).await.ok(); - let reply = shell_socket.read().await?; - shell_reply_tx.send(reply).await?; - } - anyhow::Ok(()) - } - }); - - let control_task = cx.background_spawn({ - async move { - while let Some(message) = control_request_rx.next().await { - control_socket.send(message).await.ok(); - let reply = control_socket.read().await?; - control_reply_tx.send(reply).await?; - } - anyhow::Ok(()) - } - }); - let stderr = process.stderr.take(); - - cx.spawn(async move |_cx| { - if stderr.is_none() { - return; - } - let reader = BufReader::new(stderr.unwrap()); - let mut lines = reader.lines(); - while let Some(Ok(line)) = lines.next().await { - log::error!("kernel: {}", line); - } - }) - .detach(); - let stdout = process.stdout.take(); cx.spawn(async move |_cx| { - if stdout.is_none() { - return; - } - let reader = BufReader::new(stdout.unwrap()); - let mut lines = reader.lines(); - while let Some(Ok(line)) = lines.next().await { - log::info!("kernel: {}", line); + use futures::future::Either; + + let stderr_lines = match stderr { + Some(s) => Either::Left( + BufReader::new(s) + .lines() + .map(|line| (log::Level::Error, line)), + ), + None => Either::Right(futures::stream::empty()), + }; + let stdout_lines = match stdout { + Some(s) => Either::Left( + BufReader::new(s) + .lines() + .map(|line| (log::Level::Info, line)), + ), + None => Either::Right(futures::stream::empty()), + }; + let mut lines = futures::stream::select(stderr_lines, stdout_lines); + while let Some((level, Ok(line))) = lines.next().await { + log::log!(level, "kernel: {}", line); } }) .detach(); @@ -283,9 +245,7 @@ impl NativeRunningKernel { } let mut tasks = FuturesUnordered::new(); - tasks.push(with_name("iopub task", iopub_task)); - tasks.push(with_name("shell task", shell_task)); - tasks.push(with_name("control task", control_task)); + tasks.push(with_name("recv task", recv_task)); tasks.push(with_name("routing task", routing_task)); while let Some((name, result)) = tasks.next().await {