test_support.rs

  1use acp_thread::{AgentConnection, StubAgentConnection};
  2use agent_client_protocol as acp;
  3use agent_servers::{AgentServer, AgentServerDelegate};
  4use gpui::{Entity, Task, TestAppContext, VisualTestContext};
  5use project::AgentId;
  6use project::Project;
  7use settings::SettingsStore;
  8use std::any::Any;
  9use std::rc::Rc;
 10
 11use crate::AgentPanel;
 12use crate::agent_panel;
 13
 14pub struct StubAgentServer<C> {
 15    connection: C,
 16    agent_id: AgentId,
 17}
 18
 19impl<C> StubAgentServer<C>
 20where
 21    C: AgentConnection,
 22{
 23    pub fn new(connection: C) -> Self {
 24        Self {
 25            connection,
 26            agent_id: "Test".into(),
 27        }
 28    }
 29
 30    pub fn with_connection_agent_id(mut self) -> Self {
 31        self.agent_id = self.connection.agent_id();
 32        self
 33    }
 34}
 35
 36impl StubAgentServer<StubAgentConnection> {
 37    pub fn default_response() -> Self {
 38        let conn = StubAgentConnection::new();
 39        conn.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk(
 40            acp::ContentChunk::new("Default response".into()),
 41        )]);
 42        Self::new(conn)
 43    }
 44}
 45
 46impl<C> AgentServer for StubAgentServer<C>
 47where
 48    C: 'static + AgentConnection + Send + Clone,
 49{
 50    fn logo(&self) -> ui::IconName {
 51        ui::IconName::ZedAgent
 52    }
 53
 54    fn agent_id(&self) -> AgentId {
 55        self.agent_id.clone()
 56    }
 57
 58    fn connect(
 59        &self,
 60        _delegate: AgentServerDelegate,
 61        _project: Entity<Project>,
 62        _cx: &mut gpui::App,
 63    ) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
 64        Task::ready(Ok(Rc::new(self.connection.clone())))
 65    }
 66
 67    fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
 68        self
 69    }
 70}
 71
 72pub fn init_test(cx: &mut TestAppContext) {
 73    cx.update(|cx| {
 74        let settings_store = SettingsStore::test(cx);
 75        cx.set_global(settings_store);
 76        cx.set_global(acp_thread::StubSessionCounter(
 77            std::sync::atomic::AtomicUsize::new(0),
 78        ));
 79        theme_settings::init(theme::LoadThemes::JustBase, cx);
 80        editor::init(cx);
 81        release_channel::init("0.0.0".parse().unwrap(), cx);
 82        agent_panel::init(cx);
 83    });
 84}
 85
 86pub fn open_thread_with_connection(
 87    panel: &Entity<AgentPanel>,
 88    connection: StubAgentConnection,
 89    cx: &mut VisualTestContext,
 90) {
 91    panel.update_in(cx, |panel, window, cx| {
 92        panel.open_external_thread_with_server(
 93            Rc::new(StubAgentServer::new(connection)),
 94            window,
 95            cx,
 96        );
 97    });
 98    cx.run_until_parked();
 99}
100
101pub fn open_thread_with_custom_connection<C>(
102    panel: &Entity<AgentPanel>,
103    connection: C,
104    cx: &mut VisualTestContext,
105) where
106    C: 'static + AgentConnection + Send + Clone,
107{
108    panel.update_in(cx, |panel, window, cx| {
109        panel.open_external_thread_with_server(
110            Rc::new(StubAgentServer::new(connection).with_connection_agent_id()),
111            window,
112            cx,
113        );
114    });
115    cx.run_until_parked();
116}
117
118pub fn send_message(panel: &Entity<AgentPanel>, cx: &mut VisualTestContext) {
119    let thread_view = panel.read_with(cx, |panel, cx| panel.active_thread_view(cx).unwrap());
120    let message_editor = thread_view.read_with(cx, |view, _cx| view.message_editor.clone());
121    message_editor.update_in(cx, |editor, window, cx| {
122        editor.set_text("Hello", window, cx);
123    });
124    thread_view.update_in(cx, |view, window, cx| view.send(window, cx));
125    cx.run_until_parked();
126}
127
128pub fn active_session_id(panel: &Entity<AgentPanel>, cx: &VisualTestContext) -> acp::SessionId {
129    panel.read_with(cx, |panel, cx| {
130        let thread = panel.active_agent_thread(cx).unwrap();
131        thread.read(cx).session_id().clone()
132    })
133}
134
135pub fn active_thread_id(
136    panel: &Entity<AgentPanel>,
137    cx: &VisualTestContext,
138) -> crate::thread_metadata_store::ThreadId {
139    panel.read_with(cx, |panel, cx| panel.active_thread_id(cx).unwrap())
140}