debugger: Kill debug sessions on app quit (#33273)

Anthony Eid , Conrad Irwin , and Remco Smits created

Before this PR force quitting Zed would leave hanging debug adapter
processes and not allow debug adapters to clean up their sessions
properly.

This PR fixes this problem by sending a disconnect/terminate to all
debug adapters and force shutting down their processes after they
respond.

Co-authored-by: Cole Miller \<cole@zed.dev\>

Release Notes:

- debugger: Shutdown and clean up debug processes when force quitting
Zed

---------

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>
Co-authored-by: Remco Smits <djsmits12@gmail.com>

Change summary

.zed/debug.json                                |  18 +
crates/dap/src/client.rs                       |  11 
crates/dap/src/transport.rs                    |  63 +++---
crates/debugger_ui/src/session/running.rs      |  10 +
crates/debugger_ui/src/tests/debugger_panel.rs | 192 ++++++++++++++++++++
crates/project/src/debugger/session.rs         |  57 +++-
6 files changed, 288 insertions(+), 63 deletions(-)

Detailed changes

.zed/debug.json 🔗

@@ -2,11 +2,23 @@
   {
     "label": "Debug Zed (CodeLLDB)",
     "adapter": "CodeLLDB",
-    "build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
+    "build": {
+      "label": "Build Zed",
+      "command": "cargo",
+      "args": [
+        "build"
+      ]
+    }
   },
   {
     "label": "Debug Zed (GDB)",
     "adapter": "GDB",
-    "build": { "label": "Build Zed", "command": "cargo", "args": ["build"] }
-  }
+    "build": {
+      "label": "Build Zed",
+      "command": "cargo",
+      "args": [
+        "build"
+      ]
+    }
+  },
 ]

crates/dap/src/client.rs 🔗

@@ -163,8 +163,9 @@ impl DebugAdapterClient {
         self.sequence_count.fetch_add(1, Ordering::Relaxed)
     }
 
-    pub async fn shutdown(&self) -> Result<()> {
-        self.transport_delegate.shutdown().await
+    pub fn kill(&self) {
+        log::debug!("Killing DAP process");
+        self.transport_delegate.transport.lock().kill();
     }
 
     pub fn has_adapter_logs(&self) -> bool {
@@ -315,8 +316,6 @@ mod tests {
             },
             response
         );
-
-        client.shutdown().await.unwrap();
     }
 
     #[gpui::test]
@@ -368,8 +367,6 @@ mod tests {
             called_event_handler.load(std::sync::atomic::Ordering::SeqCst),
             "Event handler was not called"
         );
-
-        client.shutdown().await.unwrap();
     }
 
     #[gpui::test]
@@ -433,7 +430,5 @@ mod tests {
             called_event_handler.load(std::sync::atomic::Ordering::SeqCst),
             "Event handler was not called"
         );
-
-        client.shutdown().await.unwrap();
     }
 }

crates/dap/src/transport.rs 🔗

@@ -63,7 +63,7 @@ pub trait Transport: Send + Sync {
             Box<dyn AsyncRead + Unpin + Send + 'static>,
         )>,
     >;
-    fn kill(&self);
+    fn kill(&mut self);
     #[cfg(any(test, feature = "test-support"))]
     fn as_fake(&self) -> &FakeTransport {
         unreachable!()
@@ -93,12 +93,18 @@ async fn start(
 
 pub(crate) struct TransportDelegate {
     log_handlers: LogHandlers,
-    pending_requests: Requests,
+    pub(crate) pending_requests: Requests,
     pub(crate) transport: Mutex<Box<dyn Transport>>,
-    server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
+    pub(crate) server_tx: smol::lock::Mutex<Option<Sender<Message>>>,
     tasks: Mutex<Vec<Task<()>>>,
 }
 
+impl Drop for TransportDelegate {
+    fn drop(&mut self) {
+        self.transport.lock().kill()
+    }
+}
+
 impl TransportDelegate {
     pub(crate) async fn start(binary: &DebugAdapterBinary, cx: &mut AsyncApp) -> Result<Self> {
         let log_handlers: LogHandlers = Default::default();
@@ -354,7 +360,6 @@ impl TransportDelegate {
         let mut content_length = None;
         loop {
             buffer.truncate(0);
-
             match reader.read_line(buffer).await {
                 Ok(0) => return ConnectionResult::ConnectionReset,
                 Ok(_) => {}
@@ -412,21 +417,6 @@ impl TransportDelegate {
         ConnectionResult::Result(message)
     }
 
-    pub async fn shutdown(&self) -> Result<()> {
-        log::debug!("Start shutdown client");
-
-        if let Some(server_tx) = self.server_tx.lock().await.take().as_ref() {
-            server_tx.close();
-        }
-
-        self.pending_requests.lock().clear();
-        self.transport.lock().kill();
-
-        log::debug!("Shutdown client completed");
-
-        anyhow::Ok(())
-    }
-
     pub fn has_adapter_logs(&self) -> bool {
         self.transport.lock().has_adapter_logs()
     }
@@ -546,7 +536,7 @@ impl Transport for TcpTransport {
         true
     }
 
-    fn kill(&self) {
+    fn kill(&mut self) {
         if let Some(process) = &mut *self.process.lock() {
             process.kill();
         }
@@ -613,13 +603,13 @@ impl Transport for TcpTransport {
 impl Drop for TcpTransport {
     fn drop(&mut self) {
         if let Some(mut p) = self.process.lock().take() {
-            p.kill();
+            p.kill()
         }
     }
 }
 
 pub struct StdioTransport {
-    process: Mutex<Child>,
+    process: Mutex<Option<Child>>,
     _stderr_task: Option<Task<()>>,
 }
 
@@ -660,7 +650,7 @@ impl StdioTransport {
             ))
         });
 
-        let process = Mutex::new(process);
+        let process = Mutex::new(Some(process));
 
         Ok(Self {
             process,
@@ -674,8 +664,10 @@ impl Transport for StdioTransport {
         false
     }
 
-    fn kill(&self) {
-        self.process.lock().kill()
+    fn kill(&mut self) {
+        if let Some(process) = &mut *self.process.lock() {
+            process.kill();
+        }
     }
 
     fn connect(
@@ -686,8 +678,9 @@ impl Transport for StdioTransport {
             Box<dyn AsyncRead + Unpin + Send + 'static>,
         )>,
     > {
-        let mut process = self.process.lock();
         let result = util::maybe!({
+            let mut guard = self.process.lock();
+            let process = guard.as_mut().context("oops")?;
             Ok((
                 Box::new(process.stdin.take().context("Cannot reconnect")?) as _,
                 Box::new(process.stdout.take().context("Cannot reconnect")?) as _,
@@ -703,7 +696,9 @@ impl Transport for StdioTransport {
 
 impl Drop for StdioTransport {
     fn drop(&mut self) {
-        self.process.get_mut().kill();
+        if let Some(process) = &mut *self.process.lock() {
+            process.kill();
+        }
     }
 }
 
@@ -723,6 +718,7 @@ pub struct FakeTransport {
 
     stdin_writer: Option<PipeWriter>,
     stdout_reader: Option<PipeReader>,
+    message_handler: Option<Task<Result<()>>>,
 }
 
 #[cfg(any(test, feature = "test-support"))]
@@ -774,18 +770,19 @@ impl FakeTransport {
         let (stdin_writer, stdin_reader) = async_pipe::pipe();
         let (stdout_writer, stdout_reader) = async_pipe::pipe();
 
-        let this = Self {
+        let mut this = Self {
             request_handlers: Arc::new(Mutex::new(HashMap::default())),
             response_handlers: Arc::new(Mutex::new(HashMap::default())),
             stdin_writer: Some(stdin_writer),
             stdout_reader: Some(stdout_reader),
+            message_handler: None,
         };
 
         let request_handlers = this.request_handlers.clone();
         let response_handlers = this.response_handlers.clone();
         let stdout_writer = Arc::new(smol::lock::Mutex::new(stdout_writer));
 
-        cx.background_spawn(async move {
+        this.message_handler = Some(cx.background_spawn(async move {
             let mut reader = BufReader::new(stdin_reader);
             let mut buffer = String::new();
 
@@ -833,7 +830,6 @@ impl FakeTransport {
                                             .unwrap();
 
                                     let mut writer = stdout_writer.lock().await;
-
                                     writer
                                         .write_all(
                                             TransportDelegate::build_rpc_message(message)
@@ -870,8 +866,7 @@ impl FakeTransport {
                     }
                 }
             }
-        })
-        .detach();
+        }));
 
         Ok(this)
     }
@@ -904,7 +899,9 @@ impl Transport for FakeTransport {
         false
     }
 
-    fn kill(&self) {}
+    fn kill(&mut self) {
+        self.message_handler.take();
+    }
 
     #[cfg(any(test, feature = "test-support"))]
     fn as_fake(&self) -> &FakeTransport {

crates/debugger_ui/src/session/running.rs 🔗

@@ -701,6 +701,16 @@ impl RunningState {
             BreakpointList::new(Some(session.clone()), workspace.clone(), &project, cx);
 
         let _subscriptions = vec![
+            cx.on_app_quit(move |this, cx| {
+                let shutdown = this
+                    .session
+                    .update(cx, |session, cx| session.on_app_quit(cx));
+                let terminal = this.debug_terminal.clone();
+                async move {
+                    shutdown.await;
+                    drop(terminal)
+                }
+            }),
             cx.observe(&module_list, |_, _, cx| cx.notify()),
             cx.subscribe_in(&session, window, |this, _, event, window, cx| {
                 match event {

crates/debugger_ui/src/tests/debugger_panel.rs 🔗

@@ -1755,3 +1755,195 @@ async fn test_active_debug_line_setting(executor: BackgroundExecutor, cx: &mut T
         );
     });
 }
+
+#[gpui::test]
+async fn test_debug_adapters_shutdown_on_app_quit(
+    executor: BackgroundExecutor,
+    cx: &mut TestAppContext,
+) {
+    init_test(cx);
+
+    let fs = FakeFs::new(executor.clone());
+
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "main.rs": "First line\nSecond line\nThird line\nFourth line",
+        }),
+    )
+    .await;
+
+    let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+    let workspace = init_test_workspace(&project, cx).await;
+    let cx = &mut VisualTestContext::from_window(*workspace, cx);
+
+    let session = start_debug_session(&workspace, cx, |_| {}).unwrap();
+    let client = session.update(cx, |session, _| session.adapter_client().unwrap());
+
+    let disconnect_request_received = Arc::new(AtomicBool::new(false));
+    let disconnect_clone = disconnect_request_received.clone();
+
+    let disconnect_clone_for_handler = disconnect_clone.clone();
+    client.on_request::<Disconnect, _>(move |_, _| {
+        disconnect_clone_for_handler.store(true, Ordering::SeqCst);
+        Ok(())
+    });
+
+    executor.run_until_parked();
+
+    workspace
+        .update(cx, |workspace, _, cx| {
+            let panel = workspace.panel::<DebugPanel>(cx).unwrap();
+            panel.read_with(cx, |panel, _| {
+                assert!(
+                    !panel.sessions().is_empty(),
+                    "Debug session should be active"
+                );
+            });
+        })
+        .unwrap();
+
+    cx.update(|_, cx| cx.defer(|cx| cx.shutdown()));
+
+    executor.run_until_parked();
+
+    assert!(
+        disconnect_request_received.load(Ordering::SeqCst),
+        "Disconnect request should have been sent to the adapter on app shutdown"
+    );
+}
+
+#[gpui::test]
+async fn test_adapter_shutdown_with_child_sessions_on_app_quit(
+    executor: BackgroundExecutor,
+    cx: &mut TestAppContext,
+) {
+    init_test(cx);
+
+    let fs = FakeFs::new(executor.clone());
+
+    fs.insert_tree(
+        path!("/project"),
+        json!({
+            "main.rs": "First line\nSecond line\nThird line\nFourth line",
+        }),
+    )
+    .await;
+
+    let project = Project::test(fs, [path!("/project").as_ref()], cx).await;
+    let workspace = init_test_workspace(&project, cx).await;
+    let cx = &mut VisualTestContext::from_window(*workspace, cx);
+
+    let parent_session = start_debug_session(&workspace, cx, |_| {}).unwrap();
+    let parent_session_id = cx.read(|cx| parent_session.read(cx).session_id());
+    let parent_client = parent_session.update(cx, |session, _| session.adapter_client().unwrap());
+
+    let disconnect_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
+    let parent_disconnect_called = Arc::new(AtomicBool::new(false));
+    let parent_disconnect_clone = parent_disconnect_called.clone();
+    let disconnect_count_clone = disconnect_count.clone();
+
+    parent_client.on_request::<Disconnect, _>(move |_, _| {
+        parent_disconnect_clone.store(true, Ordering::SeqCst);
+        disconnect_count_clone.fetch_add(1, Ordering::SeqCst);
+
+        for _ in 0..50 {
+            if disconnect_count_clone.load(Ordering::SeqCst) >= 2 {
+                break;
+            }
+            std::thread::sleep(std::time::Duration::from_millis(1));
+        }
+
+        Ok(())
+    });
+
+    parent_client
+        .on_response::<StartDebugging, _>(move |_| {})
+        .await;
+    let _subscription = project::debugger::test::intercept_debug_sessions(cx, |_| {});
+
+    parent_client
+        .fake_reverse_request::<StartDebugging>(StartDebuggingRequestArguments {
+            configuration: json!({}),
+            request: StartDebuggingRequestArgumentsRequest::Launch,
+        })
+        .await;
+
+    cx.run_until_parked();
+
+    let child_session = project.update(cx, |project, cx| {
+        project
+            .dap_store()
+            .read(cx)
+            .session_by_id(SessionId(1))
+            .unwrap()
+    });
+    let child_session_id = cx.read(|cx| child_session.read(cx).session_id());
+    let child_client = child_session.update(cx, |session, _| session.adapter_client().unwrap());
+
+    let child_disconnect_called = Arc::new(AtomicBool::new(false));
+    let child_disconnect_clone = child_disconnect_called.clone();
+    let disconnect_count_clone = disconnect_count.clone();
+
+    child_client.on_request::<Disconnect, _>(move |_, _| {
+        child_disconnect_clone.store(true, Ordering::SeqCst);
+        disconnect_count_clone.fetch_add(1, Ordering::SeqCst);
+
+        for _ in 0..50 {
+            if disconnect_count_clone.load(Ordering::SeqCst) >= 2 {
+                break;
+            }
+            std::thread::sleep(std::time::Duration::from_millis(1));
+        }
+
+        Ok(())
+    });
+
+    executor.run_until_parked();
+
+    project.update(cx, |project, cx| {
+        let store = project.dap_store().read(cx);
+        assert!(store.session_by_id(parent_session_id).is_some());
+        assert!(store.session_by_id(child_session_id).is_some());
+    });
+
+    cx.update(|_, cx| cx.defer(|cx| cx.shutdown()));
+
+    executor.run_until_parked();
+
+    let parent_disconnect_check = parent_disconnect_called.clone();
+    let child_disconnect_check = child_disconnect_called.clone();
+    let both_disconnected = executor
+        .spawn(async move {
+            let parent_disconnect = parent_disconnect_check;
+            let child_disconnect = child_disconnect_check;
+
+            // We only have 100ms to shutdown the app
+            for _ in 0..100 {
+                if parent_disconnect.load(Ordering::SeqCst)
+                    && child_disconnect.load(Ordering::SeqCst)
+                {
+                    return true;
+                }
+
+                gpui::Timer::after(std::time::Duration::from_millis(1)).await;
+            }
+
+            false
+        })
+        .await;
+
+    assert!(
+        both_disconnected,
+        "Both parent and child sessions should receive disconnect requests"
+    );
+
+    assert!(
+        parent_disconnect_called.load(Ordering::SeqCst),
+        "Parent session should have received disconnect request"
+    );
+    assert!(
+        child_disconnect_called.load(Ordering::SeqCst),
+        "Child session should have received disconnect request"
+    );
+}

crates/project/src/debugger/session.rs 🔗

@@ -790,7 +790,7 @@ impl Session {
                 BreakpointStoreEvent::SetDebugLine | BreakpointStoreEvent::ClearDebugLines => {}
             })
             .detach();
-            cx.on_app_quit(Self::on_app_quit).detach();
+            // cx.on_app_quit(Self::on_app_quit).detach();
 
             let this = Self {
                 mode: Mode::Building,
@@ -945,6 +945,37 @@ impl Session {
         self.parent_session.as_ref()
     }
 
+    pub fn on_app_quit(&mut self, cx: &mut Context<Self>) -> Task<()> {
+        let Some(client) = self.adapter_client() else {
+            return Task::ready(());
+        };
+
+        let supports_terminate = self
+            .capabilities
+            .support_terminate_debuggee
+            .unwrap_or(false);
+
+        cx.background_spawn(async move {
+            if supports_terminate {
+                client
+                    .request::<dap::requests::Terminate>(dap::TerminateArguments {
+                        restart: Some(false),
+                    })
+                    .await
+                    .ok();
+            } else {
+                client
+                    .request::<dap::requests::Disconnect>(dap::DisconnectArguments {
+                        restart: Some(false),
+                        terminate_debuggee: Some(true),
+                        suspend_debuggee: Some(false),
+                    })
+                    .await
+                    .ok();
+            }
+        })
+    }
+
     pub fn capabilities(&self) -> &Capabilities {
         &self.capabilities
     }
@@ -1818,17 +1849,11 @@ impl Session {
         }
     }
 
-    fn on_app_quit(&mut self, cx: &mut Context<Self>) -> Task<()> {
-        let debug_adapter = self.adapter_client();
-
-        cx.background_spawn(async move {
-            if let Some(client) = debug_adapter {
-                client.shutdown().await.log_err();
-            }
-        })
-    }
-
     pub fn shutdown(&mut self, cx: &mut Context<Self>) -> Task<()> {
+        if self.is_session_terminated {
+            return Task::ready(());
+        }
+
         self.is_session_terminated = true;
         self.thread_states.exit_all_threads();
         cx.notify();
@@ -1859,14 +1884,8 @@ impl Session {
 
         cx.emit(SessionStateEvent::Shutdown);
 
-        let debug_client = self.adapter_client();
-
-        cx.background_spawn(async move {
-            let _ = task.await;
-
-            if let Some(client) = debug_client {
-                client.shutdown().await.log_err();
-            }
+        cx.spawn(async move |_, _| {
+            task.await;
         })
     }