1use acp_thread::{AgentConnection, StubAgentConnection};
2use agent_client_protocol as acp;
3use agent_servers::{AgentServer, AgentServerDelegate};
4use gpui::{Entity, Task, TestAppContext, VisualTestContext};
5use project::AgentId;
6use project::Project;
7use settings::SettingsStore;
8use std::any::Any;
9use std::cell::RefCell;
10use std::rc::Rc;
11
12use crate::AgentPanel;
13use crate::agent_panel;
14
15thread_local! {
16 static STUB_AGENT_CONNECTION: RefCell<Option<StubAgentConnection>> = const { RefCell::new(None) };
17}
18
19/// Registers a `StubAgentConnection` that will be used by `Agent::Stub`.
20///
21/// Returns the same connection so callers can hold onto it and control
22/// the stub's behavior (e.g. `connection.set_next_prompt_updates(...)`).
23pub fn set_stub_agent_connection(connection: StubAgentConnection) -> StubAgentConnection {
24 STUB_AGENT_CONNECTION.with(|cell| {
25 *cell.borrow_mut() = Some(connection.clone());
26 });
27 connection
28}
29
30/// Returns the shared `StubAgentConnection` used by `Agent::Stub`,
31/// creating a default one if none was registered.
32pub fn stub_agent_connection() -> StubAgentConnection {
33 STUB_AGENT_CONNECTION.with(|cell| {
34 let mut borrow = cell.borrow_mut();
35 borrow.get_or_insert_with(StubAgentConnection::new).clone()
36 })
37}
38
39pub struct StubAgentServer<C> {
40 connection: C,
41 agent_id: AgentId,
42}
43
44impl<C> StubAgentServer<C>
45where
46 C: AgentConnection,
47{
48 pub fn new(connection: C) -> Self {
49 Self {
50 connection,
51 agent_id: "Test".into(),
52 }
53 }
54
55 pub fn with_connection_agent_id(mut self) -> Self {
56 self.agent_id = self.connection.agent_id();
57 self
58 }
59}
60
61impl StubAgentServer<StubAgentConnection> {
62 pub fn default_response() -> Self {
63 let conn = StubAgentConnection::new();
64 conn.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk(
65 acp::ContentChunk::new("Default response".into()),
66 )]);
67 Self::new(conn)
68 }
69}
70
71impl<C> AgentServer for StubAgentServer<C>
72where
73 C: 'static + AgentConnection + Send + Clone,
74{
75 fn logo(&self) -> ui::IconName {
76 ui::IconName::ZedAgent
77 }
78
79 fn agent_id(&self) -> AgentId {
80 self.agent_id.clone()
81 }
82
83 fn connect(
84 &self,
85 _delegate: AgentServerDelegate,
86 _project: Entity<Project>,
87 _cx: &mut gpui::App,
88 ) -> Task<gpui::Result<Rc<dyn AgentConnection>>> {
89 Task::ready(Ok(Rc::new(self.connection.clone())))
90 }
91
92 fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
93 self
94 }
95}
96
97pub fn init_test(cx: &mut TestAppContext) {
98 cx.update(|cx| {
99 let settings_store = SettingsStore::test(cx);
100 cx.set_global(settings_store);
101 cx.set_global(acp_thread::StubSessionCounter(
102 std::sync::atomic::AtomicUsize::new(0),
103 ));
104 theme_settings::init(theme::LoadThemes::JustBase, cx);
105 editor::init(cx);
106 release_channel::init("0.0.0".parse().unwrap(), cx);
107 agent_panel::init(cx);
108 });
109}
110
111pub fn open_thread_with_connection(
112 panel: &Entity<AgentPanel>,
113 connection: StubAgentConnection,
114 cx: &mut VisualTestContext,
115) {
116 panel.update_in(cx, |panel, window, cx| {
117 panel.open_external_thread_with_server(
118 Rc::new(StubAgentServer::new(connection)),
119 window,
120 cx,
121 );
122 });
123 cx.run_until_parked();
124}
125
126pub fn open_thread_with_custom_connection<C>(
127 panel: &Entity<AgentPanel>,
128 connection: C,
129 cx: &mut VisualTestContext,
130) where
131 C: 'static + AgentConnection + Send + Clone,
132{
133 panel.update_in(cx, |panel, window, cx| {
134 panel.open_external_thread_with_server(
135 Rc::new(StubAgentServer::new(connection).with_connection_agent_id()),
136 window,
137 cx,
138 );
139 });
140 cx.run_until_parked();
141}
142
143pub fn send_message(panel: &Entity<AgentPanel>, cx: &mut VisualTestContext) {
144 let thread_view = panel.read_with(cx, |panel, cx| panel.active_thread_view(cx).unwrap());
145 let message_editor = thread_view.read_with(cx, |view, _cx| view.message_editor.clone());
146 message_editor.update_in(cx, |editor, window, cx| {
147 editor.set_text("Hello", window, cx);
148 });
149 thread_view.update_in(cx, |view, window, cx| view.send(window, cx));
150 cx.run_until_parked();
151}
152
153pub fn active_session_id(panel: &Entity<AgentPanel>, cx: &VisualTestContext) -> acp::SessionId {
154 panel.read_with(cx, |panel, cx| {
155 let thread = panel.active_agent_thread(cx).unwrap();
156 thread.read(cx).session_id().clone()
157 })
158}
159
160pub fn active_thread_id(
161 panel: &Entity<AgentPanel>,
162 cx: &VisualTestContext,
163) -> crate::thread_metadata_store::ThreadId {
164 panel.read_with(cx, |panel, cx| panel.active_thread_id(cx).unwrap())
165}