repl: Add restart kernel action and improve shutdown (#16609)

Kyle Kelley created

- Implement restart kernel functionality
- Clean up shutdown process to properly drop messaging and exit status
tasks
- Refactor kernel state handling for better consistency

Closes #16037

Release Notes:

- repl: Added restart kernel action
- repl: Fixed issue with shutting down kernels that are in a failure
state

Change summary

crates/quick_action_bar/src/repl_menu.rs |  25 ++++
crates/repl/src/kernels.rs               |   8 +
crates/repl/src/outputs.rs               |   4 
crates/repl/src/repl.rs                  |   2 
crates/repl/src/repl_editor.rs           |  21 +++
crates/repl/src/repl_sessions_ui.rs      |  14 ++
crates/repl/src/session.rs               | 144 +++++++++++++++++++------
7 files changed, 177 insertions(+), 41 deletions(-)

Detailed changes

crates/quick_action_bar/src/repl_menu.rs 🔗

@@ -173,8 +173,6 @@ impl QuickActionBar {
                             url: format!("{}#change-kernel", ZED_REPL_DOCUMENTATION),
                         }),
                     )
-                    // TODO: Add Restart action
-                    // .action("Restart", Box::new(gpui::NoAction))
                     .custom_entry(
                         move |_cx| {
                             Label::new("Shut Down Kernel")
@@ -189,6 +187,20 @@ impl QuickActionBar {
                             }
                         },
                     )
+                    .custom_entry(
+                        move |_cx| {
+                            Label::new("Restart Kernel")
+                                .size(LabelSize::Small)
+                                .color(Color::Error)
+                                .into_any_element()
+                        },
+                        {
+                            let editor = editor.clone();
+                            move |cx| {
+                                repl::restart(editor.clone(), cx);
+                            }
+                        },
+                    )
                     .separator()
                     .action("View Sessions", Box::new(repl::Sessions))
                     // TODO: Add shut down all kernels action
@@ -305,6 +317,15 @@ fn session_state(session: View<Session>, cx: &WindowContext) -> ReplMenuState {
     };
 
     let menu_state = match &session.kernel {
+        Kernel::Restarting => ReplMenuState {
+            tooltip: format!("Restarting {}", kernel_name).into(),
+            icon_is_animating: true,
+            popover_disabled: true,
+            icon_color: Color::Muted,
+            indicator: Some(Indicator::dot().color(Color::Muted)),
+            status: session.kernel.status(),
+            ..fill_fields()
+        },
         Kernel::RunningKernel(kernel) => match &kernel.execution_state {
             ExecutionState::Idle => ReplMenuState {
                 tooltip: format!("Run code on {} ({})", kernel_name, kernel_language).into(),

crates/repl/src/kernels.rs 🔗

@@ -87,6 +87,7 @@ pub enum KernelStatus {
     Error,
     ShuttingDown,
     Shutdown,
+    Restarting,
 }
 
 impl KernelStatus {
@@ -107,6 +108,7 @@ impl ToString for KernelStatus {
             KernelStatus::Error => "Error".to_string(),
             KernelStatus::ShuttingDown => "Shutting Down".to_string(),
             KernelStatus::Shutdown => "Shutdown".to_string(),
+            KernelStatus::Restarting => "Restarting".to_string(),
         }
     }
 }
@@ -122,6 +124,7 @@ impl From<&Kernel> for KernelStatus {
             Kernel::ErroredLaunch(_) => KernelStatus::Error,
             Kernel::ShuttingDown => KernelStatus::ShuttingDown,
             Kernel::Shutdown => KernelStatus::Shutdown,
+            Kernel::Restarting => KernelStatus::Restarting,
         }
     }
 }
@@ -133,6 +136,7 @@ pub enum Kernel {
     ErroredLaunch(String),
     ShuttingDown,
     Shutdown,
+    Restarting,
 }
 
 impl Kernel {
@@ -160,7 +164,7 @@ impl Kernel {
 
     pub fn is_shutting_down(&self) -> bool {
         match self {
-            Kernel::ShuttingDown => true,
+            Kernel::Restarting | Kernel::ShuttingDown => true,
             Kernel::RunningKernel(_)
             | Kernel::StartingKernel(_)
             | Kernel::ErroredLaunch(_)
@@ -324,7 +328,7 @@ impl RunningKernel {
                     _control_task: control_task,
                     _routing_task: routing_task,
                     connection_path,
-                    execution_state: ExecutionState::Busy,
+                    execution_state: ExecutionState::Idle,
                     kernel_info: None,
                 },
                 messages_rx,

crates/repl/src/outputs.rs 🔗

@@ -420,6 +420,7 @@ pub enum ExecutionStatus {
     ShuttingDown,
     Shutdown,
     KernelErrored(String),
+    Restarting,
 }
 
 pub struct ExecutionView {
@@ -613,6 +614,9 @@ impl Render for ExecutionView {
             ExecutionStatus::ShuttingDown => Label::new("Kernel shutting down...")
                 .color(Color::Muted)
                 .into_any_element(),
+            ExecutionStatus::Restarting => Label::new("Kernel restarting...")
+                .color(Color::Muted)
+                .into_any_element(),
             ExecutionStatus::Shutdown => Label::new("Kernel shutdown")
                 .color(Color::Muted)
                 .into_any_element(),

crates/repl/src/repl.rs 🔗

@@ -20,7 +20,7 @@ pub use crate::jupyter_settings::JupyterSettings;
 pub use crate::kernels::{Kernel, KernelSpecification, KernelStatus};
 pub use crate::repl_editor::*;
 pub use crate::repl_sessions_ui::{
-    ClearOutputs, Interrupt, ReplSessionsPage, Run, Sessions, Shutdown,
+    ClearOutputs, Interrupt, ReplSessionsPage, Restart, Run, Sessions, Shutdown,
 };
 use crate::repl_store::ReplStore;
 pub use crate::session::Session;

crates/repl/src/repl_editor.rs 🔗

@@ -168,6 +168,27 @@ pub fn shutdown(editor: WeakView<Editor>, cx: &mut WindowContext) {
     });
 }
 
+pub fn restart(editor: WeakView<Editor>, cx: &mut WindowContext) {
+    let Some(editor) = editor.upgrade() else {
+        return;
+    };
+
+    let entity_id = editor.entity_id();
+
+    let Some(session) = ReplStore::global(cx)
+        .read(cx)
+        .get_session(entity_id)
+        .cloned()
+    else {
+        return;
+    };
+
+    session.update(cx, |session, cx| {
+        session.restart(cx);
+        cx.notify();
+    });
+}
+
 fn cell_range(buffer: &BufferSnapshot, start_row: u32, end_row: u32) -> Range<Point> {
     let mut snippet_end_row = end_row;
     while buffer.is_line_blank(snippet_end_row) && snippet_end_row > start_row {

crates/repl/src/repl_sessions_ui.rs 🔗

@@ -23,6 +23,7 @@ actions!(
         Sessions,
         Interrupt,
         Shutdown,
+        Restart,
         RefreshKernelspecs
     ]
 );
@@ -126,6 +127,19 @@ pub fn init(cx: &mut AppContext) {
                 }
             })
             .detach();
+
+        editor
+            .register_action({
+                let editor_handle = editor_handle.clone();
+                move |_: &Restart, cx| {
+                    if !JupyterSettings::enabled(cx) {
+                        return;
+                    }
+
+                    crate::restart(editor_handle.clone(), cx);
+                }
+            })
+            .detach();
     })
     .detach();
 }

crates/repl/src/session.rs 🔗

@@ -31,10 +31,12 @@ use theme::ActiveTheme;
 use ui::{prelude::*, IconButtonShape, Tooltip};
 
 pub struct Session {
+    fs: Arc<dyn Fs>,
     editor: WeakView<Editor>,
     pub kernel: Kernel,
     blocks: HashMap<String, EditorBlock>,
-    messaging_task: Task<()>,
+    messaging_task: Option<Task<()>>,
+    process_status_task: Option<Task<()>>,
     pub kernel_specification: KernelSpecification,
     telemetry: Arc<Telemetry>,
     _buffer_subscription: Subscription,
@@ -192,24 +194,50 @@ impl Session {
         kernel_specification: KernelSpecification,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let kernel_language = kernel_specification.kernelspec.language.clone();
+        let subscription = match editor.upgrade() {
+            Some(editor) => {
+                let buffer = editor.read(cx).buffer().clone();
+                cx.subscribe(&buffer, Self::on_buffer_event)
+            }
+            None => Subscription::new(|| {}),
+        };
+
+        let mut session = Self {
+            fs,
+            editor,
+            kernel: Kernel::StartingKernel(Task::ready(()).shared()),
+            messaging_task: None,
+            process_status_task: None,
+            blocks: HashMap::default(),
+            kernel_specification,
+            _buffer_subscription: subscription,
+            telemetry,
+        };
+
+        session.start_kernel(cx);
+        session
+    }
+
+    fn start_kernel(&mut self, cx: &mut ViewContext<Self>) {
+        let kernel_language = self.kernel_specification.kernelspec.language.clone();
+        let entity_id = self.editor.entity_id();
+        let working_directory = self
+            .editor
+            .upgrade()
+            .and_then(|editor| editor.read(cx).working_directory(cx))
+            .unwrap_or_else(temp_dir);
 
-        telemetry.report_repl_event(
+        self.telemetry.report_repl_event(
             kernel_language.clone(),
             KernelStatus::Starting.to_string(),
             cx.entity_id().to_string(),
         );
 
-        let entity_id = editor.entity_id();
-        let working_directory = editor
-            .upgrade()
-            .and_then(|editor| editor.read(cx).working_directory(cx))
-            .unwrap_or_else(temp_dir);
         let kernel = RunningKernel::new(
-            kernel_specification.clone(),
+            self.kernel_specification.clone(),
             entity_id,
             working_directory,
-            fs.clone(),
+            self.fs.clone(),
             cx,
         );
 
@@ -229,6 +257,7 @@ impl Session {
                                 let reader = BufReader::new(stderr.unwrap());
                                 let mut lines = reader.lines();
                                 while let Some(Ok(line)) = lines.next().await {
+                                    // todo!(): Log stdout and stderr to something the session can show
                                     log::error!("kernel: {}", line);
                                 }
                             })
@@ -251,7 +280,7 @@ impl Session {
                             let status = kernel.process.status();
                             session.kernel(Kernel::RunningKernel(kernel), cx);
 
-                            cx.spawn(|session, mut cx| async move {
+                            let process_status_task = cx.spawn(|session, mut cx| async move {
                                 let error_message = match status.await {
                                     Ok(status) => {
                                         if status.success() {
@@ -299,10 +328,11 @@ impl Session {
                                         cx.notify();
                                     })
                                     .ok();
-                            })
-                            .detach();
+                            });
+
+                            session.process_status_task = Some(process_status_task);
 
-                            session.messaging_task = cx.spawn(|session, mut cx| async move {
+                            session.messaging_task = Some(cx.spawn(|session, mut cx| async move {
                                 while let Some(message) = messages_rx.next().await {
                                     session
                                         .update(&mut cx, |session, cx| {
@@ -310,9 +340,9 @@ impl Session {
                                         })
                                         .ok();
                                 }
-                            });
+                            }));
 
-                            // todo!(@rgbkrk): send kernelinforequest once our shell channel read/writes are split
+                            // todo!(@rgbkrk): send KernelInfoRequest once our shell channel read/writes are split
                             // cx.spawn(|this, mut cx| async move {
                             //     cx.background_executor()
                             //         .timer(Duration::from_millis(120))
@@ -336,23 +366,8 @@ impl Session {
             })
             .shared();
 
-        let subscription = match editor.upgrade() {
-            Some(editor) => {
-                let buffer = editor.read(cx).buffer().clone();
-                cx.subscribe(&buffer, Self::on_buffer_event)
-            }
-            None => Subscription::new(|| {}),
-        };
-
-        return Self {
-            editor,
-            kernel: Kernel::StartingKernel(pending_kernel),
-            messaging_task: Task::ready(()),
-            blocks: HashMap::default(),
-            kernel_specification,
-            _buffer_subscription: subscription,
-            telemetry,
-        };
+        self.kernel(Kernel::StartingKernel(pending_kernel), cx);
+        cx.notify();
     }
 
     fn on_buffer_event(
@@ -453,6 +468,7 @@ impl Session {
             .ok();
 
         let status = match &self.kernel {
+            Kernel::Restarting => ExecutionStatus::Restarting,
             Kernel::RunningKernel(_) => ExecutionStatus::Queued,
             Kernel::StartingKernel(_) => ExecutionStatus::ConnectingToKernel,
             Kernel::ErroredLaunch(error) => ExecutionStatus::KernelErrored(error.clone()),
@@ -615,6 +631,12 @@ impl Session {
                     // Give the kernel a bit of time to clean up
                     cx.background_executor().timer(Duration::from_secs(3)).await;
 
+                    this.update(&mut cx, |session, _cx| {
+                        session.messaging_task.take();
+                        session.process_status_task.take();
+                    })
+                    .ok();
+
                     kernel.process.kill().ok();
 
                     this.update(&mut cx, |session, cx| {
@@ -626,11 +648,59 @@ impl Session {
                 })
                 .detach();
             }
-            Kernel::StartingKernel(_kernel) => {
-                self.kernel = Kernel::Shutdown;
+            _ => {
+                self.messaging_task.take();
+                self.process_status_task.take();
+                self.kernel(Kernel::Shutdown, cx);
+            }
+        }
+        cx.notify();
+    }
+
+    pub fn restart(&mut self, cx: &mut ViewContext<Self>) {
+        let kernel = std::mem::replace(&mut self.kernel, Kernel::Restarting);
+
+        match kernel {
+            Kernel::Restarting => {
+                // Do nothing if already restarting
+            }
+            Kernel::RunningKernel(mut kernel) => {
+                let mut request_tx = kernel.request_tx.clone();
+
+                cx.spawn(|this, mut cx| async move {
+                    // Send shutdown request with restart flag
+                    log::debug!("restarting kernel");
+                    let message: JupyterMessage = ShutdownRequest { restart: true }.into();
+                    request_tx.try_send(message).ok();
+
+                    this.update(&mut cx, |session, _cx| {
+                        session.messaging_task.take();
+                        session.process_status_task.take();
+                    })
+                    .ok();
+
+                    // Wait for kernel to shutdown
+                    cx.background_executor().timer(Duration::from_secs(1)).await;
+
+                    // Force kill the kernel if it hasn't shut down
+                    kernel.process.kill().ok();
+
+                    // Start a new kernel
+                    this.update(&mut cx, |session, cx| {
+                        // todo!(): Differentiate between restart and restart+clear-outputs
+                        session.clear_outputs(cx);
+                        session.start_kernel(cx);
+                    })
+                    .ok();
+                })
+                .detach();
             }
             _ => {
-                self.kernel = Kernel::Shutdown;
+                // If it's not already running, we can just clean up and start a new kernel
+                self.messaging_task.take();
+                self.process_status_task.take();
+                self.clear_outputs(cx);
+                self.start_kernel(cx);
             }
         }
         cx.notify();
@@ -663,6 +733,7 @@ impl Render for Session {
             Kernel::ErroredLaunch(err) => (Some(format!("Error: {err}")), None),
             Kernel::ShuttingDown => (Some("Shutting Down".into()), None),
             Kernel::Shutdown => (Some("Shutdown".into()), None),
+            Kernel::Restarting => (Some("Restarting".into()), None),
         };
 
         KernelListItem::new(self.kernel_specification.clone())
@@ -675,6 +746,7 @@ impl Render for Session {
                 Kernel::ErroredLaunch(_) => Color::Error,
                 Kernel::ShuttingDown => Color::Modified,
                 Kernel::Shutdown => Color::Disabled,
+                Kernel::Restarting => Color::Modified,
             })
             .child(Label::new(self.kernel_specification.name.clone()))
             .children(status_text.map(|status_text| Label::new(format!("({status_text})"))))