diff --git a/Cargo.lock b/Cargo.lock index d90c6f8921577e4e388b4e78c91bdf42a95ca2fc..b93f6ddb22d0bcfb6516ef7e8933ba8ba7505e35 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -258,6 +258,7 @@ dependencies = [ "action_log", "agent-client-protocol", "anyhow", + "async-pipe", "async-trait", "chrono", "client", @@ -275,7 +276,6 @@ dependencies = [ "libc", "log", "nix 0.29.0", - "piper", "project", "release_channel", "remote", diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index 85b206248c7e4ccd039bc92e911891a8cf830727..0b547e8a0af797c0d13819869c4cac4eb7f046fb 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -6,7 +6,7 @@ publish.workspace = true license = "GPL-3.0-or-later" [features] -test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support", "dep:env_logger", "client/test-support", "dep:gpui_tokio", "reqwest_client/test-support"] +test-support = ["acp_thread/test-support", "gpui/test-support", "project/test-support", "dep:async-pipe", "dep:env_logger", "client/test-support", "dep:gpui_tokio", "reqwest_client/test-support"] e2e = [] [lints] @@ -22,6 +22,7 @@ acp_thread.workspace = true action_log.workspace = true agent-client-protocol.workspace = true anyhow.workspace = true +async-pipe = { workspace = true, optional = true } async-trait.workspace = true chrono.workspace = true client.workspace = true @@ -66,9 +67,9 @@ fs.workspace = true indoc.workspace = true acp_thread = { workspace = true, features = ["test-support"] } +async-pipe.workspace = true gpui = { workspace = true, features = ["test-support"] } gpui_tokio.workspace = true -piper = "0.2" project = { workspace = true, features = ["test-support"] } reqwest_client = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_servers/src/acp.rs b/crates/agent_servers/src/acp.rs index dae7888e65a01b09699aff59a758d200c03087e3..ce080b244dd4f9560915b99978d6c746c80d2d88 100644 --- a/crates/agent_servers/src/acp.rs +++ b/crates/agent_servers/src/acp.rs @@ -74,6 +74,7 @@ struct SessionConfigResponse { config_options: Option>, } +#[derive(Clone)] struct ConfigOptions { config_options: Rc>>, tx: Rc>>, @@ -315,16 +316,7 @@ impl AcpConnection { let status_fut = child.status(); async move |cx| { let status = status_fut.await?; - - for session in sessions.borrow().values() { - session - .thread - .update(cx, |thread, cx| { - thread.emit_load_error(LoadError::Exited { status }, cx) - }) - .ok(); - } - + emit_load_error_to_all_sessions(&sessions, LoadError::Exited { status }, cx); anyhow::Ok(()) } }); @@ -427,7 +419,7 @@ impl AcpConnection { &self.agent_capabilities.prompt_capabilities } - #[cfg(test)] + #[cfg(any(test, feature = "test-support"))] fn new_for_test( connection: Rc, sessions: Rc>>, @@ -661,6 +653,24 @@ impl AcpConnection { } } +fn emit_load_error_to_all_sessions( + sessions: &Rc>>, + error: LoadError, + cx: &mut AsyncApp, +) { + let threads: Vec<_> = sessions + .borrow() + .values() + .map(|session| session.thread.clone()) + .collect(); + + for thread in threads { + thread + .update(cx, |thread, cx| thread.emit_load_error(error.clone(), cx)) + .ok(); + } +} + impl Drop for AcpConnection { fn drop(&mut self) { if let Some(ref mut child) = self.child { @@ -1204,6 +1214,447 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error { } } +#[cfg(any(test, feature = "test-support"))] +pub mod test_support { + use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; + + use acp_thread::{ + AgentModelSelector, AgentSessionConfigOptions, AgentSessionModes, AgentSessionRetry, + AgentSessionSetTitle, AgentSessionTruncate, AgentTelemetry, UserMessageId, + }; + + use super::*; + + #[derive(Clone, Default)] + pub struct FakeAcpAgentServer { + load_session_count: Arc, + close_session_count: Arc, + fail_next_prompt: Arc, + exit_status_sender: + Arc>>>, + } + + impl FakeAcpAgentServer { + pub fn new() -> Self { + Self::default() + } + + pub fn load_session_count(&self) -> Arc { + self.load_session_count.clone() + } + + pub fn close_session_count(&self) -> Arc { + self.close_session_count.clone() + } + + pub fn simulate_server_exit(&self) { + let sender = self + .exit_status_sender + .lock() + .expect("exit status sender lock should not be poisoned") + .clone() + .expect("fake ACP server must be connected before simulating exit"); + sender + .try_send(std::process::ExitStatus::default()) + .expect("fake ACP server exit receiver should still be alive"); + } + + pub fn fail_next_prompt(&self) { + self.fail_next_prompt.store(true, Ordering::SeqCst); + } + } + + impl crate::AgentServer for FakeAcpAgentServer { + fn logo(&self) -> ui::IconName { + ui::IconName::ZedAgent + } + + fn agent_id(&self) -> AgentId { + AgentId::new("Test") + } + + fn connect( + &self, + _delegate: crate::AgentServerDelegate, + project: Entity, + cx: &mut App, + ) -> Task>> { + let load_session_count = self.load_session_count.clone(); + let close_session_count = self.close_session_count.clone(); + let fail_next_prompt = self.fail_next_prompt.clone(); + let exit_status_sender = self.exit_status_sender.clone(); + cx.spawn(async move |cx| { + let harness = build_fake_acp_connection( + project, + load_session_count, + close_session_count, + fail_next_prompt, + cx, + ) + .await?; + let (exit_tx, exit_rx) = smol::channel::bounded(1); + *exit_status_sender + .lock() + .expect("exit status sender lock should not be poisoned") = Some(exit_tx); + let connection = harness.connection.clone(); + let simulate_exit_task = cx.spawn(async move |cx| { + while let Ok(status) = exit_rx.recv().await { + emit_load_error_to_all_sessions( + &connection.sessions, + LoadError::Exited { status }, + cx, + ); + } + Ok(()) + }); + Ok(Rc::new(FakeAcpAgentConnection { + inner: harness.connection, + _keep_agent_alive: harness.keep_agent_alive, + _simulate_exit_task: simulate_exit_task, + }) as Rc) + }) + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + pub struct FakeAcpConnectionHarness { + pub connection: Rc, + pub load_session_count: Arc, + pub close_session_count: Arc, + pub keep_agent_alive: Task>, + } + + struct FakeAcpAgentConnection { + inner: Rc, + _keep_agent_alive: Task>, + _simulate_exit_task: Task>, + } + + impl AgentConnection for FakeAcpAgentConnection { + fn agent_id(&self) -> AgentId { + self.inner.agent_id() + } + + fn telemetry_id(&self) -> SharedString { + self.inner.telemetry_id() + } + + fn new_session( + self: Rc, + project: Entity, + work_dirs: PathList, + cx: &mut App, + ) -> Task>> { + self.inner.clone().new_session(project, work_dirs, cx) + } + + fn supports_load_session(&self) -> bool { + self.inner.supports_load_session() + } + + fn load_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + work_dirs: PathList, + title: Option, + cx: &mut App, + ) -> Task>> { + self.inner + .clone() + .load_session(session_id, project, work_dirs, title, cx) + } + + fn supports_close_session(&self) -> bool { + self.inner.supports_close_session() + } + + fn close_session( + self: Rc, + session_id: &acp::SessionId, + cx: &mut App, + ) -> Task> { + self.inner.clone().close_session(session_id, cx) + } + + fn supports_resume_session(&self) -> bool { + self.inner.supports_resume_session() + } + + fn resume_session( + self: Rc, + session_id: acp::SessionId, + project: Entity, + work_dirs: PathList, + title: Option, + cx: &mut App, + ) -> Task>> { + self.inner + .clone() + .resume_session(session_id, project, work_dirs, title, cx) + } + + fn auth_methods(&self) -> &[acp::AuthMethod] { + self.inner.auth_methods() + } + + fn terminal_auth_task( + &self, + method: &acp::AuthMethodId, + cx: &App, + ) -> Option>> { + self.inner.terminal_auth_task(method, cx) + } + + fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task> { + self.inner.authenticate(method, cx) + } + + fn prompt( + &self, + user_message_id: UserMessageId, + params: acp::PromptRequest, + cx: &mut App, + ) -> Task> { + self.inner.prompt(user_message_id, params, cx) + } + + fn retry( + &self, + session_id: &acp::SessionId, + cx: &App, + ) -> Option> { + self.inner.retry(session_id, cx) + } + + fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) { + self.inner.cancel(session_id, cx) + } + + fn truncate( + &self, + session_id: &acp::SessionId, + cx: &App, + ) -> Option> { + self.inner.truncate(session_id, cx) + } + + fn set_title( + &self, + session_id: &acp::SessionId, + cx: &App, + ) -> Option> { + self.inner.set_title(session_id, cx) + } + + fn model_selector( + &self, + session_id: &acp::SessionId, + ) -> Option> { + self.inner.model_selector(session_id) + } + + fn telemetry(&self) -> Option> { + self.inner.telemetry() + } + + fn session_modes( + &self, + session_id: &acp::SessionId, + cx: &App, + ) -> Option> { + self.inner.session_modes(session_id, cx) + } + + fn session_config_options( + &self, + session_id: &acp::SessionId, + cx: &App, + ) -> Option> { + self.inner.session_config_options(session_id, cx) + } + + fn session_list(&self, cx: &mut App) -> Option> { + self.inner.session_list(cx) + } + + fn into_any(self: Rc) -> Rc { + self + } + } + + struct FakeAcpAgent { + load_session_count: Arc, + close_session_count: Arc, + fail_next_prompt: Arc, + } + + #[async_trait::async_trait(?Send)] + impl acp::Agent for FakeAcpAgent { + async fn initialize( + &self, + args: acp::InitializeRequest, + ) -> acp::Result { + Ok( + acp::InitializeResponse::new(args.protocol_version).agent_capabilities( + acp::AgentCapabilities::default() + .load_session(true) + .session_capabilities( + acp::SessionCapabilities::default() + .close(acp::SessionCloseCapabilities::new()), + ), + ), + ) + } + + async fn authenticate( + &self, + _: acp::AuthenticateRequest, + ) -> acp::Result { + Ok(Default::default()) + } + + async fn new_session( + &self, + _: acp::NewSessionRequest, + ) -> acp::Result { + Ok(acp::NewSessionResponse::new(acp::SessionId::new("unused"))) + } + + async fn prompt(&self, _: acp::PromptRequest) -> acp::Result { + if self.fail_next_prompt.swap(false, Ordering::SeqCst) { + Err(acp::ErrorCode::InternalError.into()) + } else { + Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)) + } + } + + async fn cancel(&self, _: acp::CancelNotification) -> acp::Result<()> { + Ok(()) + } + + async fn load_session( + &self, + _: acp::LoadSessionRequest, + ) -> acp::Result { + self.load_session_count.fetch_add(1, Ordering::SeqCst); + Ok(acp::LoadSessionResponse::new()) + } + + async fn close_session( + &self, + _: acp::CloseSessionRequest, + ) -> acp::Result { + self.close_session_count.fetch_add(1, Ordering::SeqCst); + Ok(acp::CloseSessionResponse::new()) + } + } + + async fn build_fake_acp_connection( + project: Entity, + load_session_count: Arc, + close_session_count: Arc, + fail_next_prompt: Arc, + cx: &mut AsyncApp, + ) -> Result { + let (c2a_writer, c2a_reader) = async_pipe::pipe(); + let (a2c_writer, a2c_reader) = async_pipe::pipe(); + + let sessions: Rc>> = + Rc::new(RefCell::new(HashMap::default())); + let session_list_container: Rc>>> = + Rc::new(RefCell::new(None)); + + let foreground = cx.foreground_executor().clone(); + + let client_delegate = ClientDelegate { + sessions: sessions.clone(), + session_list: session_list_container, + cx: cx.clone(), + }; + + let (client_conn, client_io_task) = + acp::ClientSideConnection::new(client_delegate, c2a_writer, a2c_reader, { + let foreground = foreground.clone(); + move |fut| { + foreground.spawn(fut).detach(); + } + }); + + let fake_agent = FakeAcpAgent { + load_session_count: load_session_count.clone(), + close_session_count: close_session_count.clone(), + fail_next_prompt, + }; + + let (_, agent_io_task) = + acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, { + let foreground = foreground.clone(); + move |fut| { + foreground.spawn(fut).detach(); + } + }); + + let client_io_task = cx.background_spawn(client_io_task); + let agent_io_task = cx.background_spawn(agent_io_task); + + let response = client_conn + .initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1)) + .await?; + + let agent_capabilities = response.agent_capabilities; + + let agent_server_store = + project.read_with(cx, |project, _| project.agent_server_store().downgrade()); + + let connection = cx.update(|cx| { + AcpConnection::new_for_test( + Rc::new(client_conn), + sessions, + agent_capabilities, + agent_server_store, + client_io_task, + cx, + ) + }); + + let keep_agent_alive = cx.background_spawn(async move { + agent_io_task.await.ok(); + anyhow::Ok(()) + }); + + Ok(FakeAcpConnectionHarness { + connection: Rc::new(connection), + load_session_count, + close_session_count, + keep_agent_alive, + }) + } + + pub async fn connect_fake_acp_connection( + project: Entity, + cx: &mut gpui::TestAppContext, + ) -> FakeAcpConnectionHarness { + cx.update(|cx| { + let store = settings::SettingsStore::test(cx); + cx.set_global(store); + }); + + build_fake_acp_connection( + project, + Arc::new(AtomicUsize::new(0)), + Arc::new(AtomicUsize::new(0)), + Arc::new(AtomicBool::new(false)), + &mut cx.to_async(), + ) + .await + .expect("failed to initialize ACP connection") + } +} + #[cfg(test)] mod tests { use std::sync::atomic::{AtomicUsize, Ordering}; @@ -1420,8 +1871,8 @@ mod tests { let load_count = Arc::new(AtomicUsize::new(0)); let close_count = Arc::new(AtomicUsize::new(0)); - let (c2a_reader, c2a_writer) = piper::pipe(4096); - let (a2c_reader, a2c_writer) = piper::pipe(4096); + let (c2a_writer, c2a_reader) = async_pipe::pipe(); + let (a2c_writer, a2c_reader) = async_pipe::pipe(); let sessions: Rc>> = Rc::new(RefCell::new(HashMap::default())); @@ -1866,17 +2317,24 @@ impl acp::Client for ClientDelegate { &self, notification: acp::SessionNotification, ) -> Result<(), acp::Error> { - let sessions = self.sessions.borrow(); - let session = sessions - .get(¬ification.session_id) - .context("Failed to get session")?; + let (thread, session_modes, session_config_options) = { + let sessions = self.sessions.borrow(); + let session = sessions + .get(¬ification.session_id) + .context("Failed to get session")?; + ( + session.thread.clone(), + session.session_modes.clone(), + session.config_options.clone(), + ) + }; if let acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate { current_mode_id, .. }) = ¬ification.update { - if let Some(session_modes) = &session.session_modes { + if let Some(session_modes) = &session_modes { session_modes.borrow_mut().current_mode_id = current_mode_id.clone(); } } @@ -1886,7 +2344,7 @@ impl acp::Client for ClientDelegate { .. }) = ¬ification.update { - if let Some(opts) = &session.config_options { + if let Some(opts) = &session_config_options { *opts.config_options.borrow_mut() = config_options.clone(); opts.tx.borrow_mut().send(()).ok(); } @@ -1913,7 +2371,7 @@ impl acp::Client for ClientDelegate { .and_then(|v| v.as_str().map(PathBuf::from)); // Create a minimal display-only lower-level terminal and register it. - let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| { + let _ = thread.update(&mut self.cx.clone(), |thread, cx| { let builder = TerminalBuilder::new_display_only( CursorShape::default(), AlternateScroll::On, @@ -1941,7 +2399,7 @@ impl acp::Client for ClientDelegate { } // Forward the update to the acp_thread as usual. - session.thread.update(&mut self.cx.clone(), |thread, cx| { + thread.update(&mut self.cx.clone(), |thread, cx| { thread.handle_session_update(notification.update.clone(), cx) })??; @@ -1953,7 +2411,7 @@ impl acp::Client for ClientDelegate { let terminal_id = acp::TerminalId::new(id_str); if let Some(s) = term_out.get("data").and_then(|v| v.as_str()) { let data = s.as_bytes().to_vec(); - let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| { + let _ = thread.update(&mut self.cx.clone(), |thread, cx| { thread.on_terminal_provider_event( TerminalProviderEvent::Output { terminal_id, data }, cx, @@ -1980,7 +2438,7 @@ impl acp::Client for ClientDelegate { .and_then(|v| v.as_str().map(|s| s.to_string())), ); - let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| { + let _ = thread.update(&mut self.cx.clone(), |thread, cx| { thread.on_terminal_provider_event( TerminalProviderEvent::Exit { terminal_id, diff --git a/crates/agent_servers/src/agent_servers.rs b/crates/agent_servers/src/agent_servers.rs index 2016e5aaaa27b62c956c5eee49c989172980de49..f609a5f50aef3af9a0a27482e0022ac0cee8d501 100644 --- a/crates/agent_servers/src/agent_servers.rs +++ b/crates/agent_servers/src/agent_servers.rs @@ -17,6 +17,10 @@ use gpui::{App, AppContext, Entity, Task}; use settings::SettingsStore; use std::{any::Any, rc::Rc, sync::Arc}; +#[cfg(any(test, feature = "test-support"))] +pub use acp::test_support::{ + FakeAcpAgentServer, FakeAcpConnectionHarness, connect_fake_acp_connection, +}; pub use acp::{AcpConnection, GEMINI_TERMINAL_AUTH_METHOD_ID}; pub struct AgentServerDelegate { diff --git a/crates/agent_ui/Cargo.toml b/crates/agent_ui/Cargo.toml index 64114c4f2a4b10b0566250ba772b6b3188b14123..3813e99bcd165009177e4d515e569eff61c29941 100644 --- a/crates/agent_ui/Cargo.toml +++ b/crates/agent_ui/Cargo.toml @@ -116,6 +116,7 @@ reqwest_client = { workspace = true, optional = true } [dev-dependencies] acp_thread = { workspace = true, features = ["test-support"] } agent = { workspace = true, features = ["test-support"] } +agent_servers = { workspace = true, features = ["test-support"] } buffer_diff = { workspace = true, features = ["test-support"] } client = { workspace = true, features = ["test-support"] } clock = { workspace = true, features = ["test-support"] } diff --git a/crates/agent_ui/src/conversation_view.rs b/crates/agent_ui/src/conversation_view.rs index bb19274711b5e654cab775c32bd6766b5d84b1f5..787fe774c3b7864ebe679b7bef1525f8a5e6ec49 100644 --- a/crates/agent_ui/src/conversation_view.rs +++ b/crates/agent_ui/src/conversation_view.rs @@ -2863,6 +2863,7 @@ pub(crate) mod tests { use action_log::ActionLog; use agent::{AgentTool, EditFileTool, FetchTool, TerminalTool, ToolPermissionContext}; use agent_client_protocol::SessionId; + use agent_servers::FakeAcpAgentServer; use editor::MultiBufferOffset; use fs::FakeFs; use gpui::{EventEmitter, TestAppContext, VisualTestContext}; @@ -2972,8 +2973,8 @@ pub(crate) mod tests { async fn test_notification_for_error(cx: &mut TestAppContext) { init_test(cx); - let (conversation_view, cx) = - setup_conversation_view(StubAgentServer::new(SaboteurAgentConnection), cx).await; + let server = FakeAcpAgentServer::new(); + let (conversation_view, cx) = setup_conversation_view(server.clone(), cx).await; let message_editor = message_editor(&conversation_view, cx); message_editor.update_in(cx, |editor, window, cx| { @@ -2981,6 +2982,7 @@ pub(crate) mod tests { }); cx.deactivate_window(); + server.fail_next_prompt(); active_thread(&conversation_view, cx) .update_in(cx, |view, window, cx| view.send(window, cx)); @@ -2994,6 +2996,34 @@ pub(crate) mod tests { ); } + #[gpui::test] + async fn test_acp_server_exit_transitions_conversation_to_load_error_without_panic( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let server = FakeAcpAgentServer::new(); + let close_session_count = server.close_session_count(); + let (conversation_view, cx) = setup_conversation_view(server.clone(), cx).await; + + cx.run_until_parked(); + + server.simulate_server_exit(); + cx.run_until_parked(); + + conversation_view.read_with(cx, |view, _cx| { + assert!( + matches!(view.server_state, ServerState::LoadError { .. }), + "Conversation should transition to LoadError when an ACP thread exits" + ); + }); + assert_eq!( + close_session_count.load(std::sync::atomic::Ordering::SeqCst), + 1, + "ConversationView should close the ACP session after a thread exit" + ); + } + #[gpui::test] async fn test_recent_history_refreshes_when_history_cache_updated(cx: &mut TestAppContext) { init_test(cx); @@ -4520,75 +4550,6 @@ pub(crate) mod tests { } } - #[derive(Clone)] - struct SaboteurAgentConnection; - - impl AgentConnection for SaboteurAgentConnection { - fn agent_id(&self) -> AgentId { - AgentId::new("saboteur") - } - - fn telemetry_id(&self) -> SharedString { - "saboteur".into() - } - - fn new_session( - self: Rc, - project: Entity, - work_dirs: PathList, - cx: &mut gpui::App, - ) -> Task>> { - Task::ready(Ok(cx.new(|cx| { - let action_log = cx.new(|_| ActionLog::new(project.clone())); - AcpThread::new( - None, - None, - Some(work_dirs), - self, - project, - action_log, - SessionId::new("test"), - watch::Receiver::constant( - acp::PromptCapabilities::new() - .image(true) - .audio(true) - .embedded_context(true), - ), - cx, - ) - }))) - } - - fn auth_methods(&self) -> &[acp::AuthMethod] { - &[] - } - - fn authenticate( - &self, - _method_id: acp::AuthMethodId, - _cx: &mut App, - ) -> Task> { - unimplemented!() - } - - fn prompt( - &self, - _id: acp_thread::UserMessageId, - _params: acp::PromptRequest, - _cx: &mut App, - ) -> Task> { - Task::ready(Err(anyhow::anyhow!("Error prompting"))) - } - - fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) { - unimplemented!() - } - - fn into_any(self: Rc) -> Rc { - self - } - } - /// Simulates a model which always returns a refusal response #[derive(Clone)] struct RefusalAgentConnection;