diff --git a/Cargo.lock b/Cargo.lock index 1010e09b4b5987752a5344dcc41d66d76ff63e1e..775f2d56f70f6eb1519b366f510674f23fc2c1ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -164,6 +164,7 @@ dependencies = [ "libc", "log", "nix 0.29.0", + "parking_lot", "paths", "project", "schemars", diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 6e01511850c83398061bba41ac14d457bd952cee..4168f08a53e1b3ad72c31cc8141a726ab9206b72 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -49,7 +49,7 @@ impl UserMessage { } fn to_markdown(&self, cx: &App) -> String { - format!("## User\n{}\n", self.content.to_markdown(cx)) + format!("## User\n\n{}\n\n", self.content.to_markdown(cx)) } } @@ -1711,7 +1711,7 @@ mod tests { .unwrap(); drop(end_turn_tx); - request.await.unwrap(); + assert!(request.await.unwrap_err().to_string().contains("canceled")); thread.read_with(cx, |thread, _| { assert!(matches!( @@ -1816,6 +1816,7 @@ mod tests { struct FakeAgent { server: Entity, cx: AsyncApp, + cancel_tx: Rc>>>, } impl acp_old::Agent for FakeAgent { @@ -1834,6 +1835,9 @@ mod tests { } async fn cancel_send_message(&self) -> Result<(), acp_old::Error> { + if let Some(cancel_tx) = self.cancel_tx.take() { + cancel_tx.send(()).log_err(); + } Ok(()) } @@ -1841,6 +1845,9 @@ mod tests { &self, request: acp_old::SendUserMessageParams, ) -> Result<(), acp_old::Error> { + let (cancel_tx, cancel_rx) = oneshot::channel(); + self.cancel_tx.replace(Some(cancel_tx)); + let mut cx = self.cx.clone(); let handler = self .server @@ -1848,7 +1855,10 @@ mod tests { .ok() .flatten(); if let Some(handler) = handler { - handler(request, self.server.clone(), self.cx.clone()).await + select! { + _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()), + _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()), + } } else { Err(anyhow::anyhow!("No handler for on_user_message").into()) } @@ -1860,6 +1870,7 @@ mod tests { let agent = FakeAgent { server: cx.entity(), cx: cx.to_async(), + cancel_tx: Default::default(), }; let foreground_executor = cx.foreground_executor().clone(); diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 75d682012bebc0576997c889a2065d1a0150be7f..62c2bf73f76a765aa358226e7978773a11ee493f 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -27,12 +27,14 @@ futures.workspace = true gpui.workspace = true itertools.workspace = true log.workspace = true +parking_lot.workspace = true paths.workspace = true project.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true +shlex.workspace = true smol.workspace = true strum.workspace = true tempfile.workspace = true @@ -41,7 +43,6 @@ util.workspace = true uuid.workspace = true watch.workspace = true which.workspace = true -shlex.workspace = true workspace-hack.workspace = true [target.'cfg(unix)'.dependencies] diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 0a0b94b832d23c3a28875c438c8d4eeb918321e8..7c1c2358063988f29f97bc6a04f15b3cdcbf152d 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -2,7 +2,6 @@ mod claude; mod codex; mod gemini; mod settings; -mod stdio_agent_server; #[cfg(test)] mod e2e_tests; @@ -11,7 +10,6 @@ pub use claude::*; pub use codex::*; pub use gemini::*; pub use settings::*; -pub use stdio_agent_server::*; use acp_thread::AgentConnection; use anyhow::Result; diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 6b4561de1bae49ba2b2f272f6516f258c63b489c..5f35b4af734fc7cd7daf9cdae07ecff42c0f3985 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -291,8 +291,15 @@ impl AgentConnection for ClaudeAgentConnection { .log_err() .is_some() { + let end_turn_tx = session.end_turn_tx.clone(); cx.foreground_executor() - .spawn(async move { done_rx.await? }) + .spawn(async move { + done_rx.await??; + if let Some(end_turn_tx) = end_turn_tx.take() { + end_turn_tx.send(Ok(())).ok(); + } + anyhow::Ok(()) + }) .detach_and_log_err(cx); } } diff --git a/crates/agent_servers/src/gemini.rs b/crates/agent_servers/src/gemini.rs index 678d9f2297e085fba9765f5df6a78823d82ba927..47b965cdada579019364f8798abb3c0baef691ff 100644 --- a/crates/agent_servers/src/gemini.rs +++ b/crates/agent_servers/src/gemini.rs @@ -1,9 +1,17 @@ -use crate::stdio_agent_server::StdioAgentServer; -use crate::{AgentServerCommand, AgentServerVersion}; +use anyhow::anyhow; +use std::cell::RefCell; +use std::path::Path; +use std::rc::Rc; +use util::ResultExt as _; + +use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; +use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; +use agentic_coding_protocol as acp_old; use anyhow::{Context as _, Result}; -use gpui::{AsyncApp, Entity}; +use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity}; use project::Project; use settings::SettingsStore; +use ui::App; use crate::AllAgentServersSettings; @@ -12,7 +20,7 @@ pub struct Gemini; const ACP_ARG: &str = "--experimental-acp"; -impl StdioAgentServer for Gemini { +impl AgentServer for Gemini { fn name(&self) -> &'static str { "Gemini" } @@ -29,6 +37,84 @@ impl StdioAgentServer for Gemini { ui::IconName::AiGemini } + fn connect( + &self, + root_dir: &Path, + project: &Entity, + cx: &mut App, + ) -> Task>> { + let root_dir = root_dir.to_path_buf(); + let project = project.clone(); + let this = self.clone(); + let name = self.name(); + + cx.spawn(async move |cx| { + let command = this.command(&project, cx).await?; + + let mut child = util::command::new_smol_command(&command.path) + .args(command.args.iter()) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + let stdin = child.stdin.take().unwrap(); + let stdout = child.stdout.take().unwrap(); + + let foreground_executor = cx.foreground_executor().clone(); + + let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); + + let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( + OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), + stdin, + stdout, + move |fut| foreground_executor.spawn(fut).detach(), + ); + + let io_task = cx.background_spawn(async move { + io_fut.await.log_err(); + }); + + let child_status = cx.background_spawn(async move { + let result = match child.status().await { + Err(e) => Err(anyhow!(e)), + Ok(result) if result.success() => Ok(()), + Ok(result) => { + if let Some(AgentServerVersion::Unsupported { + error_message, + upgrade_message, + upgrade_command, + }) = this.version(&command).await.log_err() + { + Err(anyhow!(LoadError::Unsupported { + error_message, + upgrade_message, + upgrade_command + })) + } else { + Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) + } + } + }; + drop(io_task); + result + }); + + let connection: Rc = Rc::new(OldAcpAgentConnection { + name, + connection, + child_status, + }); + + Ok(connection) + }) + } +} + +impl Gemini { async fn command( &self, project: &Entity, diff --git a/crates/agent_servers/src/stdio_agent_server.rs b/crates/agent_servers/src/stdio_agent_server.rs deleted file mode 100644 index 0e10f6ba6693e3c0fa08192d93ea7bfa2e0b6c1d..0000000000000000000000000000000000000000 --- a/crates/agent_servers/src/stdio_agent_server.rs +++ /dev/null @@ -1,120 +0,0 @@ -use crate::{AgentServer, AgentServerCommand, AgentServerVersion}; -use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate}; -use agentic_coding_protocol as acp_old; -use anyhow::{Result, anyhow}; -use gpui::{App, AsyncApp, Entity, Task, WeakEntity, prelude::*}; -use project::Project; -use std::{cell::RefCell, path::Path, rc::Rc}; -use util::ResultExt; - -pub trait StdioAgentServer: Send + Clone { - fn logo(&self) -> ui::IconName; - fn name(&self) -> &'static str; - fn empty_state_headline(&self) -> &'static str; - fn empty_state_message(&self) -> &'static str; - - fn command( - &self, - project: &Entity, - cx: &mut AsyncApp, - ) -> impl Future>; - - fn version( - &self, - command: &AgentServerCommand, - ) -> impl Future> + Send; -} - -impl AgentServer for T { - fn name(&self) -> &'static str { - self.name() - } - - fn empty_state_headline(&self) -> &'static str { - self.empty_state_headline() - } - - fn empty_state_message(&self) -> &'static str { - self.empty_state_message() - } - - fn logo(&self) -> ui::IconName { - self.logo() - } - - fn connect( - &self, - root_dir: &Path, - project: &Entity, - cx: &mut App, - ) -> Task>> { - let root_dir = root_dir.to_path_buf(); - let project = project.clone(); - let this = self.clone(); - let name = self.name(); - - cx.spawn(async move |cx| { - let command = this.command(&project, cx).await?; - - let mut child = util::command::new_smol_command(&command.path) - .args(command.args.iter()) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - - let foreground_executor = cx.foreground_executor().clone(); - - let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid())); - - let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent( - OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()), - stdin, - stdout, - move |fut| foreground_executor.spawn(fut).detach(), - ); - - let io_task = cx.background_spawn(async move { - io_fut.await.log_err(); - }); - - let child_status = cx.background_spawn(async move { - let result = match child.status().await { - Err(e) => Err(anyhow!(e)), - Ok(result) if result.success() => Ok(()), - Ok(result) => { - if let Some(AgentServerVersion::Unsupported { - error_message, - upgrade_message, - upgrade_command, - }) = this.version(&command).await.log_err() - { - Err(anyhow!(LoadError::Unsupported { - error_message, - upgrade_message, - upgrade_command - })) - } else { - Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127)))) - } - } - }; - drop(io_task); - result - }); - - let connection: Rc = Rc::new(OldAcpAgentConnection { - name, - connection, - child_status, - }); - - Ok(connection) - }) - } -} diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index 8c7def6b0ea3b01c339af5df9d1f4ae715e01e0b..98992b2b73c472cda004be2fb731ef28872580b3 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -68,6 +68,7 @@ pub struct AcpThreadView { plan_expanded: bool, editor_expanded: bool, message_history: Rc>>>, + _cancel_task: Option>, } enum ThreadState { @@ -183,6 +184,7 @@ impl AcpThreadView { plan_expanded: false, editor_expanded: false, message_history, + _cancel_task: None, } } @@ -299,7 +301,7 @@ impl AcpThreadView { self.last_error.take(); if let Some(thread) = self.thread() { - thread.update(cx, |thread, cx| thread.cancel(cx)).detach(); + self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx))); } } diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 57534c8cd540c171069751281e2bc824ed05c343..afb40ef37f539f5ff53e22e5e5b53da5f8b370d7 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -19,7 +19,7 @@ use collections::VecDeque; use debugger_ui::debugger_panel::DebugPanel; use editor::ProposedChangesEditorToolbar; use editor::{Editor, MultiBuffer}; -use feature_flags::{FeatureFlagAppExt, PanicFeatureFlag}; +use feature_flags::FeatureFlagAppExt; use futures::future::Either; use futures::{StreamExt, channel::mpsc, select_biased}; use git_ui::git_panel::GitPanel;