test_support.rs

  1use acp_thread::{AgentConnection, StubAgentConnection};
  2use agent_client_protocol as acp;
  3use agent_servers::{AgentServer, AgentServerDelegate};
  4use gpui::{Entity, Task, TestAppContext, VisualTestContext};
  5use project::AgentId;
  6use project::Project;
  7use settings::SettingsStore;
  8use std::any::Any;
  9use std::cell::RefCell;
 10use std::rc::Rc;
 11
 12use crate::AgentPanel;
 13use crate::agent_panel;
 14
 15thread_local! {
 16    static STUB_AGENT_CONNECTION: RefCell<Option<StubAgentConnection>> = const { RefCell::new(None) };
 17}
 18
 19/// Registers a `StubAgentConnection` that will be used by `Agent::Stub`.
 20///
 21/// Returns the same connection so callers can hold onto it and control
 22/// the stub's behavior (e.g. `connection.set_next_prompt_updates(...)`).
 23pub fn set_stub_agent_connection(connection: StubAgentConnection) -> StubAgentConnection {
 24    STUB_AGENT_CONNECTION.with(|cell| {
 25        *cell.borrow_mut() = Some(connection.clone());
 26    });
 27    connection
 28}
 29
 30/// Returns the shared `StubAgentConnection` used by `Agent::Stub`,
 31/// creating a default one if none was registered.
 32pub fn stub_agent_connection() -> StubAgentConnection {
 33    STUB_AGENT_CONNECTION.with(|cell| {
 34        let mut borrow = cell.borrow_mut();
 35        borrow.get_or_insert_with(StubAgentConnection::new).clone()
 36    })
 37}
 38
 39pub struct StubAgentServer<C> {
 40    connection: C,
 41    agent_id: AgentId,
 42}
 43
 44impl<C> StubAgentServer<C>
 45where
 46    C: AgentConnection,
 47{
 48    pub fn new(connection: C) -> Self {
 49        Self {
 50            connection,
 51            agent_id: "Test".into(),
 52        }
 53    }
 54
 55    pub fn with_connection_agent_id(mut self) -> Self {
 56        self.agent_id = self.connection.agent_id();
 57        self
 58    }
 59}
 60
 61impl StubAgentServer<StubAgentConnection> {
 62    pub fn default_response() -> Self {
 63        let conn = StubAgentConnection::new();
 64        conn.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk(
 65            acp::ContentChunk::new("Default response".into()),
 66        )]);
 67        Self::new(conn)
 68    }
 69}
 70
 71impl<C> AgentServer for StubAgentServer<C>
 72where
 73    C: 'static + AgentConnection + Send + Clone,
 74{
 75    fn logo(&self) -> ui::IconName {
 76        ui::IconName::ZedAgent
 77    }
 78
 79    fn agent_id(&self) -> AgentId {
 80        self.agent_id.clone()
 81    }
 82
 83    fn connect(
 84        &self,
 85        _delegate: AgentServerDelegate,
 86        _project: Entity<Project>,
 87        _cx: &mut gpui::App,
 88    ) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
 89        Task::ready(Ok(Rc::new(self.connection.clone())))
 90    }
 91
 92    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 93        self
 94    }
 95}
 96
 97pub fn init_test(cx: &mut TestAppContext) {
 98    cx.update(|cx| {
 99        let settings_store = SettingsStore::test(cx);
100        cx.set_global(settings_store);
101        cx.set_global(acp_thread::StubSessionCounter(
102            std::sync::atomic::AtomicUsize::new(0),
103        ));
104        theme_settings::init(theme::LoadThemes::JustBase, cx);
105        editor::init(cx);
106        release_channel::init("0.0.0".parse().unwrap(), cx);
107        agent_panel::init(cx);
108    });
109}
110
111pub fn open_thread_with_connection(
112    panel: &Entity<AgentPanel>,
113    connection: StubAgentConnection,
114    cx: &mut VisualTestContext,
115) {
116    panel.update_in(cx, |panel, window, cx| {
117        panel.open_external_thread_with_server(
118            Rc::new(StubAgentServer::new(connection)),
119            window,
120            cx,
121        );
122    });
123    cx.run_until_parked();
124}
125
126pub fn open_thread_with_custom_connection<C>(
127    panel: &Entity<AgentPanel>,
128    connection: C,
129    cx: &mut VisualTestContext,
130) where
131    C: 'static + AgentConnection + Send + Clone,
132{
133    panel.update_in(cx, |panel, window, cx| {
134        panel.open_external_thread_with_server(
135            Rc::new(StubAgentServer::new(connection).with_connection_agent_id()),
136            window,
137            cx,
138        );
139    });
140    cx.run_until_parked();
141}
142
143pub fn send_message(panel: &Entity<AgentPanel>, cx: &mut VisualTestContext) {
144    let thread_view = panel.read_with(cx, |panel, cx| panel.active_thread_view(cx).unwrap());
145    let message_editor = thread_view.read_with(cx, |view, _cx| view.message_editor.clone());
146    message_editor.update_in(cx, |editor, window, cx| {
147        editor.set_text("Hello", window, cx);
148    });
149    thread_view.update_in(cx, |view, window, cx| view.send(window, cx));
150    cx.run_until_parked();
151}
152
153pub fn active_session_id(panel: &Entity<AgentPanel>, cx: &VisualTestContext) -> acp::SessionId {
154    panel.read_with(cx, |panel, cx| {
155        let thread = panel.active_agent_thread(cx).unwrap();
156        thread.read(cx).session_id().clone()
157    })
158}
159
160pub fn active_thread_id(
161    panel: &Entity<AgentPanel>,
162    cx: &VisualTestContext,
163) -> crate::thread_metadata_store::ThreadId {
164    panel.read_with(cx, |panel, cx| panel.active_thread_id(cx).unwrap())
165}