1use std::{path::Path, sync::Arc, time::Duration};
2
3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
4use acp_thread::{
5 AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
6};
7use agentic_coding_protocol as acp;
8use futures::{FutureExt, StreamExt, channel::mpsc, select};
9use gpui::{Entity, TestAppContext};
10use indoc::indoc;
11use project::{FakeFs, Project};
12use serde_json::json;
13use settings::{Settings, SettingsStore};
14use util::path;
15
16pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
17 let fs = init_test(cx).await;
18 let project = Project::test(fs, [], cx).await;
19 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
20
21 thread
22 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
23 .await
24 .unwrap();
25
26 thread.read_with(cx, |thread, _| {
27 assert_eq!(thread.entries().len(), 2);
28 assert!(matches!(
29 thread.entries()[0],
30 AgentThreadEntry::UserMessage(_)
31 ));
32 assert!(matches!(
33 thread.entries()[1],
34 AgentThreadEntry::AssistantMessage(_)
35 ));
36 });
37}
38
39pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
40 let _fs = init_test(cx).await;
41
42 let tempdir = tempfile::tempdir().unwrap();
43 std::fs::write(
44 tempdir.path().join("foo.rs"),
45 indoc! {"
46 fn main() {
47 println!(\"Hello, world!\");
48 }
49 "},
50 )
51 .expect("failed to write file");
52 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
53 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
54 thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 acp::SendUserMessageParams {
58 chunks: vec![
59 acp::UserMessageChunk::Text {
60 text: "Read the file ".into(),
61 },
62 acp::UserMessageChunk::Path {
63 path: Path::new("foo.rs").into(),
64 },
65 acp::UserMessageChunk::Text {
66 text: " and tell me what the content of the println! is".into(),
67 },
68 ],
69 },
70 cx,
71 )
72 })
73 .await
74 .unwrap();
75
76 thread.read_with(cx, |thread, cx| {
77 assert_eq!(thread.entries().len(), 3);
78 assert!(matches!(
79 thread.entries()[0],
80 AgentThreadEntry::UserMessage(_)
81 ));
82 assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
83 let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
84 panic!("Expected AssistantMessage")
85 };
86 assert!(
87 assistant_message.to_markdown(cx).contains("Hello, world!"),
88 "unexpected assistant message: {:?}",
89 assistant_message.to_markdown(cx)
90 );
91 });
92}
93
94pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
95 let fs = init_test(cx).await;
96 fs.insert_tree(
97 path!("/private/tmp"),
98 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
99 )
100 .await;
101 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
102 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
103
104 thread
105 .update(cx, |thread, cx| {
106 thread.send_raw(
107 "Read the '/private/tmp/foo' file and tell me what you see.",
108 cx,
109 )
110 })
111 .await
112 .unwrap();
113 thread.read_with(cx, |thread, _cx| {
114 assert!(matches!(
115 &thread.entries()[2],
116 AgentThreadEntry::ToolCall(ToolCall {
117 status: ToolCallStatus::Allowed { .. },
118 ..
119 })
120 ));
121
122 assert!(matches!(
123 thread.entries()[3],
124 AgentThreadEntry::AssistantMessage(_)
125 ));
126 });
127}
128
129pub async fn test_tool_call_with_confirmation(
130 server: impl AgentServer + 'static,
131 cx: &mut TestAppContext,
132) {
133 let fs = init_test(cx).await;
134 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
135 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
136 let full_turn = thread.update(cx, |thread, cx| {
137 thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
138 });
139
140 run_until_first_tool_call(&thread, cx).await;
141
142 let tool_call_id = thread.read_with(cx, |thread, _cx| {
143 let AgentThreadEntry::ToolCall(ToolCall {
144 id,
145 status:
146 ToolCallStatus::WaitingForConfirmation {
147 confirmation: ToolCallConfirmation::Execute { root_command, .. },
148 ..
149 },
150 ..
151 }) = &thread.entries()[2]
152 else {
153 panic!();
154 };
155
156 assert_eq!(root_command, "echo");
157
158 *id
159 });
160
161 thread.update(cx, |thread, cx| {
162 thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
163
164 assert!(matches!(
165 &thread.entries()[2],
166 AgentThreadEntry::ToolCall(ToolCall {
167 status: ToolCallStatus::Allowed { .. },
168 ..
169 })
170 ));
171 });
172
173 full_turn.await.unwrap();
174
175 thread.read_with(cx, |thread, cx| {
176 let AgentThreadEntry::ToolCall(ToolCall {
177 content: Some(ToolCallContent::Markdown { markdown }),
178 status: ToolCallStatus::Allowed { .. },
179 ..
180 }) = &thread.entries()[2]
181 else {
182 panic!();
183 };
184
185 markdown.read_with(cx, |md, _cx| {
186 assert!(
187 md.source().contains("Hello, world!"),
188 r#"Expected '{}' to contain "Hello, world!""#,
189 md.source()
190 );
191 });
192 });
193}
194
195pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
196 let fs = init_test(cx).await;
197
198 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
199 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
200 let full_turn = thread.update(cx, |thread, cx| {
201 thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
202 });
203
204 let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
205
206 thread.read_with(cx, |thread, _cx| {
207 let AgentThreadEntry::ToolCall(ToolCall {
208 id,
209 status:
210 ToolCallStatus::WaitingForConfirmation {
211 confirmation: ToolCallConfirmation::Execute { root_command, .. },
212 ..
213 },
214 ..
215 }) = &thread.entries()[first_tool_call_ix]
216 else {
217 panic!("{:?}", thread.entries()[1]);
218 };
219
220 assert_eq!(root_command, "echo");
221
222 *id
223 });
224
225 thread
226 .update(cx, |thread, cx| thread.cancel(cx))
227 .await
228 .unwrap();
229 full_turn.await.unwrap();
230 thread.read_with(cx, |thread, _| {
231 let AgentThreadEntry::ToolCall(ToolCall {
232 status: ToolCallStatus::Canceled,
233 ..
234 }) = &thread.entries()[first_tool_call_ix]
235 else {
236 panic!();
237 };
238 });
239
240 thread
241 .update(cx, |thread, cx| {
242 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
243 })
244 .await
245 .unwrap();
246 thread.read_with(cx, |thread, _| {
247 assert!(matches!(
248 &thread.entries().last().unwrap(),
249 AgentThreadEntry::AssistantMessage(..),
250 ))
251 });
252}
253
/// Generates a `common_e2e` module in the invoking module containing one `#[gpui::test]` per
/// shared scenario in `$crate::e2e_tests` (`basic`, `path_mentions`, `tool_call`,
/// `tool_call_with_confirmation`, `cancel`).
///
/// `$server` is an expression producing the `AgentServer` under test. It is pasted into each
/// generated test separately, so it is evaluated once per test. Each test is marked
/// `#[ignore]` unless the crate is built with the `e2e` feature.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            // Lets `$server` refer to items from the module that invoked the macro.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
292
293// Helpers
294
295pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
296 env_logger::try_init().ok();
297
298 cx.update(|cx| {
299 let settings_store = SettingsStore::test(cx);
300 cx.set_global(settings_store);
301 Project::init_settings(cx);
302 language::init(cx);
303 crate::settings::init(cx);
304
305 crate::AllAgentServersSettings::override_global(
306 AllAgentServersSettings {
307 claude: Some(AgentServerSettings {
308 command: crate::claude::tests::local_command(),
309 }),
310 gemini: Some(AgentServerSettings {
311 command: crate::gemini::tests::local_command(),
312 }),
313 },
314 cx,
315 );
316 });
317
318 cx.executor().allow_parking();
319
320 FakeFs::new(cx.executor())
321}
322
323pub async fn new_test_thread(
324 server: impl AgentServer + 'static,
325 project: Entity<Project>,
326 current_dir: impl AsRef<Path>,
327 cx: &mut TestAppContext,
328) -> Entity<AcpThread> {
329 let thread = cx
330 .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
331 .await
332 .unwrap();
333
334 thread
335 .update(cx, |thread, _| thread.initialize())
336 .await
337 .unwrap();
338 thread
339}
340
341pub async fn run_until_first_tool_call(
342 thread: &Entity<AcpThread>,
343 cx: &mut TestAppContext,
344) -> usize {
345 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
346
347 let subscription = cx.update(|cx| {
348 cx.subscribe(thread, move |thread, _, cx| {
349 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
350 if matches!(entry, AgentThreadEntry::ToolCall(_)) {
351 return tx.try_send(ix).unwrap();
352 }
353 }
354 })
355 });
356
357 select! {
358 // We have to use a smol timer here because
359 // cx.background_executor().timer isn't real in the test context
360 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
361 panic!("Timeout waiting for tool call")
362 }
363 ix = rx.next().fuse() => {
364 drop(subscription);
365 ix.unwrap()
366 }
367 }
368}