1use std::{path::Path, sync::Arc, time::Duration};
2
3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
4use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
5use agent_client_protocol as acp;
6
7use futures::{FutureExt, StreamExt, channel::mpsc, select};
8use gpui::{Entity, TestAppContext};
9use indoc::indoc;
10use project::{FakeFs, Project};
11use serde_json::json;
12use settings::{Settings, SettingsStore};
13use util::path;
14
15pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
16 let fs = init_test(cx).await;
17 let project = Project::test(fs, [], cx).await;
18 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
19
20 thread
21 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
22 .await
23 .unwrap();
24
25 thread.read_with(cx, |thread, _| {
26 assert_eq!(thread.entries().len(), 2);
27 assert!(matches!(
28 thread.entries()[0],
29 AgentThreadEntry::UserMessage(_)
30 ));
31 assert!(matches!(
32 thread.entries()[1],
33 AgentThreadEntry::AssistantMessage(_)
34 ));
35 });
36}
37
38pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
39 let _fs = init_test(cx).await;
40
41 let tempdir = tempfile::tempdir().unwrap();
42 std::fs::write(
43 tempdir.path().join("foo.rs"),
44 indoc! {"
45 fn main() {
46 println!(\"Hello, world!\");
47 }
48 "},
49 )
50 .expect("failed to write file");
51 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
52 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
53 thread
54 .update(cx, |thread, cx| {
55 thread.send(
56 vec![
57 acp::ContentBlock::Text(acp::TextContent {
58 text: "Read the file ".into(),
59 annotations: None,
60 }),
61 acp::ContentBlock::ResourceLink(acp::ResourceLink {
62 uri: "foo.rs".into(),
63 name: "foo.rs".into(),
64 annotations: None,
65 description: None,
66 mime_type: None,
67 size: None,
68 title: None,
69 }),
70 acp::ContentBlock::Text(acp::TextContent {
71 text: " and tell me what the content of the println! is".into(),
72 annotations: None,
73 }),
74 ],
75 cx,
76 )
77 })
78 .await
79 .unwrap();
80
81 thread.read_with(cx, |thread, cx| {
82 assert_eq!(thread.entries().len(), 3);
83 assert!(matches!(
84 thread.entries()[0],
85 AgentThreadEntry::UserMessage(_)
86 ));
87 assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
88 let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
89 panic!("Expected AssistantMessage")
90 };
91 assert!(
92 assistant_message.to_markdown(cx).contains("Hello, world!"),
93 "unexpected assistant message: {:?}",
94 assistant_message.to_markdown(cx)
95 );
96 });
97}
98
99pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
100 let fs = init_test(cx).await;
101 fs.insert_tree(
102 path!("/private/tmp"),
103 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
104 )
105 .await;
106 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
107 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
108
109 thread
110 .update(cx, |thread, cx| {
111 thread.send_raw(
112 "Read the '/private/tmp/foo' file and tell me what you see.",
113 cx,
114 )
115 })
116 .await
117 .unwrap();
118 thread.read_with(cx, |thread, _cx| {
119 assert!(thread.entries().iter().any(|entry| {
120 matches!(
121 entry,
122 AgentThreadEntry::ToolCall(ToolCall {
123 status: ToolCallStatus::Allowed { .. },
124 ..
125 })
126 )
127 }));
128 assert!(
129 thread
130 .entries()
131 .iter()
132 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
133 );
134 });
135}
136
137pub async fn test_tool_call_with_confirmation(
138 server: impl AgentServer + 'static,
139 cx: &mut TestAppContext,
140) {
141 let fs = init_test(cx).await;
142 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
143 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
144 let full_turn = thread.update(cx, |thread, cx| {
145 thread.send_raw(
146 r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
147 cx,
148 )
149 });
150
151 run_until_first_tool_call(
152 &thread,
153 |entry| {
154 matches!(
155 entry,
156 AgentThreadEntry::ToolCall(ToolCall {
157 status: ToolCallStatus::WaitingForConfirmation { .. },
158 ..
159 })
160 )
161 },
162 cx,
163 )
164 .await;
165
166 let tool_call_id = thread.read_with(cx, |thread, _cx| {
167 let AgentThreadEntry::ToolCall(ToolCall {
168 id,
169 content,
170 status: ToolCallStatus::WaitingForConfirmation { .. },
171 ..
172 }) = &thread
173 .entries()
174 .iter()
175 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
176 .unwrap()
177 else {
178 panic!();
179 };
180
181 assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
182
183 id.clone()
184 });
185
186 thread.update(cx, |thread, cx| {
187 thread.authorize_tool_call(
188 tool_call_id,
189 acp::PermissionOptionId("0".into()),
190 acp::PermissionOptionKind::AllowOnce,
191 cx,
192 );
193
194 assert!(thread.entries().iter().any(|entry| matches!(
195 entry,
196 AgentThreadEntry::ToolCall(ToolCall {
197 status: ToolCallStatus::Allowed { .. },
198 ..
199 })
200 )));
201 });
202
203 full_turn.await.unwrap();
204
205 thread.read_with(cx, |thread, cx| {
206 let AgentThreadEntry::ToolCall(ToolCall {
207 content,
208 status: ToolCallStatus::Allowed { .. },
209 ..
210 }) = thread
211 .entries()
212 .iter()
213 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
214 .unwrap()
215 else {
216 panic!();
217 };
218
219 assert!(
220 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
221 "Expected content to contain 'Hello'"
222 );
223 });
224}
225
226pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
227 let fs = init_test(cx).await;
228
229 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
230 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
231 let full_turn = thread.update(cx, |thread, cx| {
232 thread.send_raw(
233 r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
234 cx,
235 )
236 });
237
238 let first_tool_call_ix = run_until_first_tool_call(
239 &thread,
240 |entry| {
241 matches!(
242 entry,
243 AgentThreadEntry::ToolCall(ToolCall {
244 status: ToolCallStatus::WaitingForConfirmation { .. },
245 ..
246 })
247 )
248 },
249 cx,
250 )
251 .await;
252
253 thread.read_with(cx, |thread, _cx| {
254 let AgentThreadEntry::ToolCall(ToolCall {
255 id,
256 content,
257 status: ToolCallStatus::WaitingForConfirmation { .. },
258 ..
259 }) = &thread.entries()[first_tool_call_ix]
260 else {
261 panic!("{:?}", thread.entries()[1]);
262 };
263
264 assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
265
266 id.clone()
267 });
268
269 let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
270 full_turn.await.unwrap();
271 thread.read_with(cx, |thread, _| {
272 let AgentThreadEntry::ToolCall(ToolCall {
273 status: ToolCallStatus::Canceled,
274 ..
275 }) = &thread.entries()[first_tool_call_ix]
276 else {
277 panic!();
278 };
279 });
280
281 thread
282 .update(cx, |thread, cx| {
283 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
284 })
285 .await
286 .unwrap();
287 thread.read_with(cx, |thread, _| {
288 assert!(matches!(
289 &thread.entries().last().unwrap(),
290 AgentThreadEntry::AssistantMessage(..),
291 ))
292 });
293}
294
/// Generates the shared end-to-end test suite for one agent server.
///
/// `$server` is an expression that evaluates to the server under test. It is
/// pasted into every generated test, so it is evaluated once per test.
///
/// The expansion is a `common_e2e` module containing one `#[gpui::test]` per
/// shared scenario (`basic`, `path_mentions`, `tool_call`,
/// `tool_call_with_confirmation`, `cancel`). Each test forwards to the
/// matching `test_*` function in `$crate::e2e_tests`, and each is marked
/// `#[ignore]` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            // Bring the invoking module's items into scope so `$server` can
            // refer to them.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
333
334// Helpers
335
336pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
337 env_logger::try_init().ok();
338
339 cx.update(|cx| {
340 let settings_store = SettingsStore::test(cx);
341 cx.set_global(settings_store);
342 Project::init_settings(cx);
343 language::init(cx);
344 crate::settings::init(cx);
345
346 crate::AllAgentServersSettings::override_global(
347 AllAgentServersSettings {
348 claude: Some(AgentServerSettings {
349 command: crate::claude::tests::local_command(),
350 }),
351 gemini: Some(AgentServerSettings {
352 command: crate::gemini::tests::local_command(),
353 }),
354 codex: Some(AgentServerSettings {
355 command: crate::codex::tests::local_command(),
356 }),
357 },
358 cx,
359 );
360 });
361
362 cx.executor().allow_parking();
363
364 FakeFs::new(cx.executor())
365}
366
367pub async fn new_test_thread(
368 server: impl AgentServer + 'static,
369 project: Entity<Project>,
370 current_dir: impl AsRef<Path>,
371 cx: &mut TestAppContext,
372) -> Entity<AcpThread> {
373 let connection = cx
374 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
375 .await
376 .unwrap();
377
378 let thread = connection
379 .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
380 .await
381 .unwrap();
382
383 thread
384}
385
386pub async fn run_until_first_tool_call(
387 thread: &Entity<AcpThread>,
388 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
389 cx: &mut TestAppContext,
390) -> usize {
391 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
392
393 let subscription = cx.update(|cx| {
394 cx.subscribe(thread, move |thread, _, cx| {
395 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
396 if wait_until(entry) {
397 return tx.try_send(ix).unwrap();
398 }
399 }
400 })
401 });
402
403 select! {
404 // We have to use a smol timer here because
405 // cx.background_executor().timer isn't real in the test context
406 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
407 panic!("Timeout waiting for tool call")
408 }
409 ix = rx.next().fuse() => {
410 drop(subscription);
411 ix.unwrap()
412 }
413 }
414}