@@ -74,6 +74,7 @@ struct SessionConfigResponse {
config_options: Option<Vec<acp::SessionConfigOption>>,
}
+#[derive(Clone)]
struct ConfigOptions {
config_options: Rc<RefCell<Vec<acp::SessionConfigOption>>>,
tx: Rc<RefCell<watch::Sender<()>>>,
@@ -315,16 +316,7 @@ impl AcpConnection {
let status_fut = child.status();
async move |cx| {
let status = status_fut.await?;
-
- for session in sessions.borrow().values() {
- session
- .thread
- .update(cx, |thread, cx| {
- thread.emit_load_error(LoadError::Exited { status }, cx)
- })
- .ok();
- }
-
+ emit_load_error_to_all_sessions(&sessions, LoadError::Exited { status }, cx);
anyhow::Ok(())
}
});
@@ -427,7 +419,7 @@ impl AcpConnection {
&self.agent_capabilities.prompt_capabilities
}
- #[cfg(test)]
+ #[cfg(any(test, feature = "test-support"))]
fn new_for_test(
connection: Rc<acp::ClientSideConnection>,
sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
@@ -661,6 +653,24 @@ impl AcpConnection {
}
}
+fn emit_load_error_to_all_sessions(
+ sessions: &Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+ error: LoadError,
+ cx: &mut AsyncApp,
+) {
+ let threads: Vec<_> = sessions
+ .borrow()
+ .values()
+ .map(|session| session.thread.clone())
+ .collect();
+
+ for thread in threads {
+ thread
+ .update(cx, |thread, cx| thread.emit_load_error(error.clone(), cx))
+ .ok();
+ }
+}
+
impl Drop for AcpConnection {
fn drop(&mut self) {
if let Some(ref mut child) = self.child {
@@ -1204,6 +1214,447 @@ fn map_acp_error(err: acp::Error) -> anyhow::Error {
}
}
+#[cfg(any(test, feature = "test-support"))]
+pub mod test_support {
+ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+ use acp_thread::{
+ AgentModelSelector, AgentSessionConfigOptions, AgentSessionModes, AgentSessionRetry,
+ AgentSessionSetTitle, AgentSessionTruncate, AgentTelemetry, UserMessageId,
+ };
+
+ use super::*;
+
+ #[derive(Clone, Default)]
+ pub struct FakeAcpAgentServer {
+ load_session_count: Arc<AtomicUsize>,
+ close_session_count: Arc<AtomicUsize>,
+ fail_next_prompt: Arc<AtomicBool>,
+ exit_status_sender:
+ Arc<std::sync::Mutex<Option<smol::channel::Sender<std::process::ExitStatus>>>>,
+ }
+
+ impl FakeAcpAgentServer {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn load_session_count(&self) -> Arc<AtomicUsize> {
+ self.load_session_count.clone()
+ }
+
+ pub fn close_session_count(&self) -> Arc<AtomicUsize> {
+ self.close_session_count.clone()
+ }
+
+ pub fn simulate_server_exit(&self) {
+ let sender = self
+ .exit_status_sender
+ .lock()
+ .expect("exit status sender lock should not be poisoned")
+ .clone()
+ .expect("fake ACP server must be connected before simulating exit");
+ sender
+ .try_send(std::process::ExitStatus::default())
+ .expect("fake ACP server exit receiver should still be alive");
+ }
+
+ pub fn fail_next_prompt(&self) {
+ self.fail_next_prompt.store(true, Ordering::SeqCst);
+ }
+ }
+
+ impl crate::AgentServer for FakeAcpAgentServer {
+ fn logo(&self) -> ui::IconName {
+ ui::IconName::ZedAgent
+ }
+
+ fn agent_id(&self) -> AgentId {
+ AgentId::new("Test")
+ }
+
+ fn connect(
+ &self,
+ _delegate: crate::AgentServerDelegate,
+ project: Entity<Project>,
+ cx: &mut App,
+ ) -> Task<anyhow::Result<Rc<dyn AgentConnection>>> {
+ let load_session_count = self.load_session_count.clone();
+ let close_session_count = self.close_session_count.clone();
+ let fail_next_prompt = self.fail_next_prompt.clone();
+ let exit_status_sender = self.exit_status_sender.clone();
+ cx.spawn(async move |cx| {
+ let harness = build_fake_acp_connection(
+ project,
+ load_session_count,
+ close_session_count,
+ fail_next_prompt,
+ cx,
+ )
+ .await?;
+ let (exit_tx, exit_rx) = smol::channel::bounded(1);
+ *exit_status_sender
+ .lock()
+ .expect("exit status sender lock should not be poisoned") = Some(exit_tx);
+ let connection = harness.connection.clone();
+ let simulate_exit_task = cx.spawn(async move |cx| {
+ while let Ok(status) = exit_rx.recv().await {
+ emit_load_error_to_all_sessions(
+ &connection.sessions,
+ LoadError::Exited { status },
+ cx,
+ );
+ }
+ Ok(())
+ });
+ Ok(Rc::new(FakeAcpAgentConnection {
+ inner: harness.connection,
+ _keep_agent_alive: harness.keep_agent_alive,
+ _simulate_exit_task: simulate_exit_task,
+ }) as Rc<dyn AgentConnection>)
+ })
+ }
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
+ }
+
+ pub struct FakeAcpConnectionHarness {
+ pub connection: Rc<AcpConnection>,
+ pub load_session_count: Arc<AtomicUsize>,
+ pub close_session_count: Arc<AtomicUsize>,
+ pub keep_agent_alive: Task<anyhow::Result<()>>,
+ }
+
+ struct FakeAcpAgentConnection {
+ inner: Rc<AcpConnection>,
+ _keep_agent_alive: Task<anyhow::Result<()>>,
+ _simulate_exit_task: Task<anyhow::Result<()>>,
+ }
+
+ impl AgentConnection for FakeAcpAgentConnection {
+ fn agent_id(&self) -> AgentId {
+ self.inner.agent_id()
+ }
+
+ fn telemetry_id(&self) -> SharedString {
+ self.inner.telemetry_id()
+ }
+
+ fn new_session(
+ self: Rc<Self>,
+ project: Entity<Project>,
+ work_dirs: PathList,
+ cx: &mut App,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ self.inner.clone().new_session(project, work_dirs, cx)
+ }
+
+ fn supports_load_session(&self) -> bool {
+ self.inner.supports_load_session()
+ }
+
+ fn load_session(
+ self: Rc<Self>,
+ session_id: acp::SessionId,
+ project: Entity<Project>,
+ work_dirs: PathList,
+ title: Option<SharedString>,
+ cx: &mut App,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ self.inner
+ .clone()
+ .load_session(session_id, project, work_dirs, title, cx)
+ }
+
+ fn supports_close_session(&self) -> bool {
+ self.inner.supports_close_session()
+ }
+
+ fn close_session(
+ self: Rc<Self>,
+ session_id: &acp::SessionId,
+ cx: &mut App,
+ ) -> Task<Result<()>> {
+ self.inner.clone().close_session(session_id, cx)
+ }
+
+ fn supports_resume_session(&self) -> bool {
+ self.inner.supports_resume_session()
+ }
+
+ fn resume_session(
+ self: Rc<Self>,
+ session_id: acp::SessionId,
+ project: Entity<Project>,
+ work_dirs: PathList,
+ title: Option<SharedString>,
+ cx: &mut App,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ self.inner
+ .clone()
+ .resume_session(session_id, project, work_dirs, title, cx)
+ }
+
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ self.inner.auth_methods()
+ }
+
+ fn terminal_auth_task(
+ &self,
+ method: &acp::AuthMethodId,
+ cx: &App,
+ ) -> Option<Task<Result<SpawnInTerminal>>> {
+ self.inner.terminal_auth_task(method, cx)
+ }
+
+ fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
+ self.inner.authenticate(method, cx)
+ }
+
+ fn prompt(
+ &self,
+ user_message_id: UserMessageId,
+ params: acp::PromptRequest,
+ cx: &mut App,
+ ) -> Task<Result<acp::PromptResponse>> {
+ self.inner.prompt(user_message_id, params, cx)
+ }
+
+ fn retry(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &App,
+ ) -> Option<Rc<dyn AgentSessionRetry>> {
+ self.inner.retry(session_id, cx)
+ }
+
+ fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
+ self.inner.cancel(session_id, cx)
+ }
+
+ fn truncate(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &App,
+ ) -> Option<Rc<dyn AgentSessionTruncate>> {
+ self.inner.truncate(session_id, cx)
+ }
+
+ fn set_title(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &App,
+ ) -> Option<Rc<dyn AgentSessionSetTitle>> {
+ self.inner.set_title(session_id, cx)
+ }
+
+ fn model_selector(
+ &self,
+ session_id: &acp::SessionId,
+ ) -> Option<Rc<dyn AgentModelSelector>> {
+ self.inner.model_selector(session_id)
+ }
+
+ fn telemetry(&self) -> Option<Rc<dyn AgentTelemetry>> {
+ self.inner.telemetry()
+ }
+
+ fn session_modes(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &App,
+ ) -> Option<Rc<dyn AgentSessionModes>> {
+ self.inner.session_modes(session_id, cx)
+ }
+
+ fn session_config_options(
+ &self,
+ session_id: &acp::SessionId,
+ cx: &App,
+ ) -> Option<Rc<dyn AgentSessionConfigOptions>> {
+ self.inner.session_config_options(session_id, cx)
+ }
+
+ fn session_list(&self, cx: &mut App) -> Option<Rc<dyn AgentSessionList>> {
+ self.inner.session_list(cx)
+ }
+
+ fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
+ self
+ }
+ }
+
+ struct FakeAcpAgent {
+ load_session_count: Arc<AtomicUsize>,
+ close_session_count: Arc<AtomicUsize>,
+ fail_next_prompt: Arc<AtomicBool>,
+ }
+
+ #[async_trait::async_trait(?Send)]
+ impl acp::Agent for FakeAcpAgent {
+ async fn initialize(
+ &self,
+ args: acp::InitializeRequest,
+ ) -> acp::Result<acp::InitializeResponse> {
+ Ok(
+ acp::InitializeResponse::new(args.protocol_version).agent_capabilities(
+ acp::AgentCapabilities::default()
+ .load_session(true)
+ .session_capabilities(
+ acp::SessionCapabilities::default()
+ .close(acp::SessionCloseCapabilities::new()),
+ ),
+ ),
+ )
+ }
+
+ async fn authenticate(
+ &self,
+ _: acp::AuthenticateRequest,
+ ) -> acp::Result<acp::AuthenticateResponse> {
+ Ok(Default::default())
+ }
+
+ async fn new_session(
+ &self,
+ _: acp::NewSessionRequest,
+ ) -> acp::Result<acp::NewSessionResponse> {
+ Ok(acp::NewSessionResponse::new(acp::SessionId::new("unused")))
+ }
+
+ async fn prompt(&self, _: acp::PromptRequest) -> acp::Result<acp::PromptResponse> {
+ if self.fail_next_prompt.swap(false, Ordering::SeqCst) {
+ Err(acp::ErrorCode::InternalError.into())
+ } else {
+ Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
+ }
+ }
+
+ async fn cancel(&self, _: acp::CancelNotification) -> acp::Result<()> {
+ Ok(())
+ }
+
+ async fn load_session(
+ &self,
+ _: acp::LoadSessionRequest,
+ ) -> acp::Result<acp::LoadSessionResponse> {
+ self.load_session_count.fetch_add(1, Ordering::SeqCst);
+ Ok(acp::LoadSessionResponse::new())
+ }
+
+ async fn close_session(
+ &self,
+ _: acp::CloseSessionRequest,
+ ) -> acp::Result<acp::CloseSessionResponse> {
+ self.close_session_count.fetch_add(1, Ordering::SeqCst);
+ Ok(acp::CloseSessionResponse::new())
+ }
+ }
+
+ async fn build_fake_acp_connection(
+ project: Entity<Project>,
+ load_session_count: Arc<AtomicUsize>,
+ close_session_count: Arc<AtomicUsize>,
+ fail_next_prompt: Arc<AtomicBool>,
+ cx: &mut AsyncApp,
+ ) -> Result<FakeAcpConnectionHarness> {
+ let (c2a_writer, c2a_reader) = async_pipe::pipe();
+ let (a2c_writer, a2c_reader) = async_pipe::pipe();
+
+ let sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>> =
+ Rc::new(RefCell::new(HashMap::default()));
+ let session_list_container: Rc<RefCell<Option<Rc<AcpSessionList>>>> =
+ Rc::new(RefCell::new(None));
+
+ let foreground = cx.foreground_executor().clone();
+
+ let client_delegate = ClientDelegate {
+ sessions: sessions.clone(),
+ session_list: session_list_container,
+ cx: cx.clone(),
+ };
+
+ let (client_conn, client_io_task) =
+ acp::ClientSideConnection::new(client_delegate, c2a_writer, a2c_reader, {
+ let foreground = foreground.clone();
+ move |fut| {
+ foreground.spawn(fut).detach();
+ }
+ });
+
+ let fake_agent = FakeAcpAgent {
+ load_session_count: load_session_count.clone(),
+ close_session_count: close_session_count.clone(),
+ fail_next_prompt,
+ };
+
+ let (_, agent_io_task) =
+ acp::AgentSideConnection::new(fake_agent, a2c_writer, c2a_reader, {
+ let foreground = foreground.clone();
+ move |fut| {
+ foreground.spawn(fut).detach();
+ }
+ });
+
+ let client_io_task = cx.background_spawn(client_io_task);
+ let agent_io_task = cx.background_spawn(agent_io_task);
+
+ let response = client_conn
+ .initialize(acp::InitializeRequest::new(acp::ProtocolVersion::V1))
+ .await?;
+
+ let agent_capabilities = response.agent_capabilities;
+
+ let agent_server_store =
+ project.read_with(cx, |project, _| project.agent_server_store().downgrade());
+
+ let connection = cx.update(|cx| {
+ AcpConnection::new_for_test(
+ Rc::new(client_conn),
+ sessions,
+ agent_capabilities,
+ agent_server_store,
+ client_io_task,
+ cx,
+ )
+ });
+
+ let keep_agent_alive = cx.background_spawn(async move {
+ agent_io_task.await.ok();
+ anyhow::Ok(())
+ });
+
+ Ok(FakeAcpConnectionHarness {
+ connection: Rc::new(connection),
+ load_session_count,
+ close_session_count,
+ keep_agent_alive,
+ })
+ }
+
+ pub async fn connect_fake_acp_connection(
+ project: Entity<Project>,
+ cx: &mut gpui::TestAppContext,
+ ) -> FakeAcpConnectionHarness {
+ cx.update(|cx| {
+ let store = settings::SettingsStore::test(cx);
+ cx.set_global(store);
+ });
+
+ build_fake_acp_connection(
+ project,
+ Arc::new(AtomicUsize::new(0)),
+ Arc::new(AtomicUsize::new(0)),
+ Arc::new(AtomicBool::new(false)),
+ &mut cx.to_async(),
+ )
+ .await
+ .expect("failed to initialize ACP connection")
+ }
+}
+
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -1420,8 +1871,8 @@ mod tests {
let load_count = Arc::new(AtomicUsize::new(0));
let close_count = Arc::new(AtomicUsize::new(0));
- let (c2a_reader, c2a_writer) = piper::pipe(4096);
- let (a2c_reader, a2c_writer) = piper::pipe(4096);
+ let (c2a_writer, c2a_reader) = async_pipe::pipe();
+ let (a2c_writer, a2c_reader) = async_pipe::pipe();
let sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>> =
Rc::new(RefCell::new(HashMap::default()));
@@ -1866,17 +2317,24 @@ impl acp::Client for ClientDelegate {
&self,
notification: acp::SessionNotification,
) -> Result<(), acp::Error> {
- let sessions = self.sessions.borrow();
- let session = sessions
- .get(¬ification.session_id)
- .context("Failed to get session")?;
+ let (thread, session_modes, session_config_options) = {
+ let sessions = self.sessions.borrow();
+ let session = sessions
+ .get(¬ification.session_id)
+ .context("Failed to get session")?;
+ (
+ session.thread.clone(),
+ session.session_modes.clone(),
+ session.config_options.clone(),
+ )
+ };
if let acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate {
current_mode_id,
..
}) = ¬ification.update
{
- if let Some(session_modes) = &session.session_modes {
+ if let Some(session_modes) = &session_modes {
session_modes.borrow_mut().current_mode_id = current_mode_id.clone();
}
}
@@ -1886,7 +2344,7 @@ impl acp::Client for ClientDelegate {
..
}) = ¬ification.update
{
- if let Some(opts) = &session.config_options {
+ if let Some(opts) = &session_config_options {
*opts.config_options.borrow_mut() = config_options.clone();
opts.tx.borrow_mut().send(()).ok();
}
@@ -1913,7 +2371,7 @@ impl acp::Client for ClientDelegate {
.and_then(|v| v.as_str().map(PathBuf::from));
// Create a minimal display-only lower-level terminal and register it.
- let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| {
+ let _ = thread.update(&mut self.cx.clone(), |thread, cx| {
let builder = TerminalBuilder::new_display_only(
CursorShape::default(),
AlternateScroll::On,
@@ -1941,7 +2399,7 @@ impl acp::Client for ClientDelegate {
}
// Forward the update to the acp_thread as usual.
- session.thread.update(&mut self.cx.clone(), |thread, cx| {
+ thread.update(&mut self.cx.clone(), |thread, cx| {
thread.handle_session_update(notification.update.clone(), cx)
})??;
@@ -1953,7 +2411,7 @@ impl acp::Client for ClientDelegate {
let terminal_id = acp::TerminalId::new(id_str);
if let Some(s) = term_out.get("data").and_then(|v| v.as_str()) {
let data = s.as_bytes().to_vec();
- let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| {
+ let _ = thread.update(&mut self.cx.clone(), |thread, cx| {
thread.on_terminal_provider_event(
TerminalProviderEvent::Output { terminal_id, data },
cx,
@@ -1980,7 +2438,7 @@ impl acp::Client for ClientDelegate {
.and_then(|v| v.as_str().map(|s| s.to_string())),
);
- let _ = session.thread.update(&mut self.cx.clone(), |thread, cx| {
+ let _ = thread.update(&mut self.cx.clone(), |thread, cx| {
thread.on_terminal_provider_event(
TerminalProviderEvent::Exit {
terminal_id,
@@ -2863,6 +2863,7 @@ pub(crate) mod tests {
use action_log::ActionLog;
use agent::{AgentTool, EditFileTool, FetchTool, TerminalTool, ToolPermissionContext};
use agent_client_protocol::SessionId;
+ use agent_servers::FakeAcpAgentServer;
use editor::MultiBufferOffset;
use fs::FakeFs;
use gpui::{EventEmitter, TestAppContext, VisualTestContext};
@@ -2972,8 +2973,8 @@ pub(crate) mod tests {
async fn test_notification_for_error(cx: &mut TestAppContext) {
init_test(cx);
- let (conversation_view, cx) =
- setup_conversation_view(StubAgentServer::new(SaboteurAgentConnection), cx).await;
+ let server = FakeAcpAgentServer::new();
+ let (conversation_view, cx) = setup_conversation_view(server.clone(), cx).await;
let message_editor = message_editor(&conversation_view, cx);
message_editor.update_in(cx, |editor, window, cx| {
@@ -2981,6 +2982,7 @@ pub(crate) mod tests {
});
cx.deactivate_window();
+ server.fail_next_prompt();
active_thread(&conversation_view, cx)
.update_in(cx, |view, window, cx| view.send(window, cx));
@@ -2994,6 +2996,34 @@ pub(crate) mod tests {
);
}
+ #[gpui::test]
+ async fn test_acp_server_exit_transitions_conversation_to_load_error_without_panic(
+ cx: &mut TestAppContext,
+ ) {
+ init_test(cx);
+
+ let server = FakeAcpAgentServer::new();
+ let close_session_count = server.close_session_count();
+ let (conversation_view, cx) = setup_conversation_view(server.clone(), cx).await;
+
+ cx.run_until_parked();
+
+ server.simulate_server_exit();
+ cx.run_until_parked();
+
+ conversation_view.read_with(cx, |view, _cx| {
+ assert!(
+ matches!(view.server_state, ServerState::LoadError { .. }),
+ "Conversation should transition to LoadError when an ACP thread exits"
+ );
+ });
+ assert_eq!(
+ close_session_count.load(std::sync::atomic::Ordering::SeqCst),
+ 1,
+ "ConversationView should close the ACP session after a thread exit"
+ );
+ }
+
#[gpui::test]
async fn test_recent_history_refreshes_when_history_cache_updated(cx: &mut TestAppContext) {
init_test(cx);
@@ -4520,75 +4550,6 @@ pub(crate) mod tests {
}
}
- #[derive(Clone)]
- struct SaboteurAgentConnection;
-
- impl AgentConnection for SaboteurAgentConnection {
- fn agent_id(&self) -> AgentId {
- AgentId::new("saboteur")
- }
-
- fn telemetry_id(&self) -> SharedString {
- "saboteur".into()
- }
-
- fn new_session(
- self: Rc<Self>,
- project: Entity<Project>,
- work_dirs: PathList,
- cx: &mut gpui::App,
- ) -> Task<gpui::Result<Entity<AcpThread>>> {
- Task::ready(Ok(cx.new(|cx| {
- let action_log = cx.new(|_| ActionLog::new(project.clone()));
- AcpThread::new(
- None,
- None,
- Some(work_dirs),
- self,
- project,
- action_log,
- SessionId::new("test"),
- watch::Receiver::constant(
- acp::PromptCapabilities::new()
- .image(true)
- .audio(true)
- .embedded_context(true),
- ),
- cx,
- )
- })))
- }
-
- fn auth_methods(&self) -> &[acp::AuthMethod] {
- &[]
- }
-
- fn authenticate(
- &self,
- _method_id: acp::AuthMethodId,
- _cx: &mut App,
- ) -> Task<gpui::Result<()>> {
- unimplemented!()
- }
-
- fn prompt(
- &self,
- _id: acp_thread::UserMessageId,
- _params: acp::PromptRequest,
- _cx: &mut App,
- ) -> Task<gpui::Result<acp::PromptResponse>> {
- Task::ready(Err(anyhow::anyhow!("Error prompting")))
- }
-
- fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
- unimplemented!()
- }
-
- fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
- self
- }
- }
-
/// Simulates a model which always returns a refusal response
#[derive(Clone)]
struct RefusalAgentConnection;