1use std::{path::Path, sync::Arc, time::Duration};
2
3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
4use acp_thread::{
5 AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
6};
7use agentic_coding_protocol as acp;
8use futures::{FutureExt, StreamExt, channel::mpsc, select};
9use gpui::{Entity, TestAppContext};
10use indoc::indoc;
11use project::{FakeFs, Project};
12use serde_json::json;
13use settings::{Settings, SettingsStore};
14use util::path;
15
16pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
17 let fs = init_test(cx).await;
18 let project = Project::test(fs, [], cx).await;
19 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
20
21 thread
22 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
23 .await
24 .unwrap();
25
26 thread.read_with(cx, |thread, _| {
27 assert_eq!(thread.entries().len(), 2);
28 assert!(matches!(
29 thread.entries()[0],
30 AgentThreadEntry::UserMessage(_)
31 ));
32 assert!(matches!(
33 thread.entries()[1],
34 AgentThreadEntry::AssistantMessage(_)
35 ));
36 });
37}
38
39pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
40 let _fs = init_test(cx).await;
41
42 let tempdir = tempfile::tempdir().unwrap();
43 std::fs::write(
44 tempdir.path().join("foo.rs"),
45 indoc! {"
46 fn main() {
47 println!(\"Hello, world!\");
48 }
49 "},
50 )
51 .expect("failed to write file");
52 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
53 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
54 thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 acp::SendUserMessageParams {
58 chunks: vec![
59 acp::UserMessageChunk::Text {
60 text: "Read the file ".into(),
61 },
62 acp::UserMessageChunk::Path {
63 path: Path::new("foo.rs").into(),
64 },
65 acp::UserMessageChunk::Text {
66 text: " and tell me what the content of the println! is".into(),
67 },
68 ],
69 },
70 cx,
71 )
72 })
73 .await
74 .unwrap();
75
76 thread.read_with(cx, |thread, cx| {
77 assert_eq!(thread.entries().len(), 3);
78 assert!(matches!(
79 thread.entries()[0],
80 AgentThreadEntry::UserMessage(_)
81 ));
82 assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
83 let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
84 panic!("Expected AssistantMessage")
85 };
86 assert!(
87 assistant_message.to_markdown(cx).contains("Hello, world!"),
88 "unexpected assistant message: {:?}",
89 assistant_message.to_markdown(cx)
90 );
91 });
92}
93
94pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
95 let fs = init_test(cx).await;
96 fs.insert_tree(
97 path!("/private/tmp"),
98 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
99 )
100 .await;
101 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
102 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
103
104 thread
105 .update(cx, |thread, cx| {
106 thread.send_raw(
107 "Read the '/private/tmp/foo' file and tell me what you see.",
108 cx,
109 )
110 })
111 .await
112 .unwrap();
113 thread.read_with(cx, |thread, _cx| {
114 assert!(thread.entries().iter().any(|entry| {
115 matches!(
116 entry,
117 AgentThreadEntry::ToolCall(ToolCall {
118 status: ToolCallStatus::Allowed { .. },
119 ..
120 })
121 )
122 }));
123 assert!(
124 thread
125 .entries()
126 .iter()
127 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
128 );
129 });
130}
131
132pub async fn test_tool_call_with_confirmation(
133 server: impl AgentServer + 'static,
134 cx: &mut TestAppContext,
135) {
136 let fs = init_test(cx).await;
137 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
138 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
139 let full_turn = thread.update(cx, |thread, cx| {
140 thread.send_raw(
141 r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
142 cx,
143 )
144 });
145
146 run_until_first_tool_call(
147 &thread,
148 |entry| {
149 matches!(
150 entry,
151 AgentThreadEntry::ToolCall(ToolCall {
152 status: ToolCallStatus::WaitingForConfirmation { .. },
153 ..
154 })
155 )
156 },
157 cx,
158 )
159 .await;
160
161 let tool_call_id = thread.read_with(cx, |thread, _cx| {
162 let AgentThreadEntry::ToolCall(ToolCall {
163 id,
164 status:
165 ToolCallStatus::WaitingForConfirmation {
166 confirmation: ToolCallConfirmation::Execute { root_command, .. },
167 ..
168 },
169 ..
170 }) = &thread
171 .entries()
172 .iter()
173 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
174 .unwrap()
175 else {
176 panic!();
177 };
178
179 assert!(root_command.contains("touch"));
180
181 *id
182 });
183
184 thread.update(cx, |thread, cx| {
185 thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
186
187 assert!(thread.entries().iter().any(|entry| matches!(
188 entry,
189 AgentThreadEntry::ToolCall(ToolCall {
190 status: ToolCallStatus::Allowed { .. },
191 ..
192 })
193 )));
194 });
195
196 full_turn.await.unwrap();
197
198 thread.read_with(cx, |thread, cx| {
199 let AgentThreadEntry::ToolCall(ToolCall {
200 content: Some(ToolCallContent::Markdown { markdown }),
201 status: ToolCallStatus::Allowed { .. },
202 ..
203 }) = thread
204 .entries()
205 .iter()
206 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
207 .unwrap()
208 else {
209 panic!();
210 };
211
212 markdown.read_with(cx, |md, _cx| {
213 assert!(
214 md.source().contains("Hello"),
215 r#"Expected '{}' to contain "Hello""#,
216 md.source()
217 );
218 });
219 });
220}
221
222pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
223 let fs = init_test(cx).await;
224
225 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
226 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
227 let full_turn = thread.update(cx, |thread, cx| {
228 thread.send_raw(
229 r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
230 cx,
231 )
232 });
233
234 let first_tool_call_ix = run_until_first_tool_call(
235 &thread,
236 |entry| {
237 matches!(
238 entry,
239 AgentThreadEntry::ToolCall(ToolCall {
240 status: ToolCallStatus::WaitingForConfirmation { .. },
241 ..
242 })
243 )
244 },
245 cx,
246 )
247 .await;
248
249 thread.read_with(cx, |thread, _cx| {
250 let AgentThreadEntry::ToolCall(ToolCall {
251 id,
252 status:
253 ToolCallStatus::WaitingForConfirmation {
254 confirmation: ToolCallConfirmation::Execute { root_command, .. },
255 ..
256 },
257 ..
258 }) = &thread.entries()[first_tool_call_ix]
259 else {
260 panic!("{:?}", thread.entries()[1]);
261 };
262
263 assert!(root_command.contains("touch"));
264
265 *id
266 });
267
268 thread
269 .update(cx, |thread, cx| thread.cancel(cx))
270 .await
271 .unwrap();
272 full_turn.await.unwrap();
273 thread.read_with(cx, |thread, _| {
274 let AgentThreadEntry::ToolCall(ToolCall {
275 status: ToolCallStatus::Canceled,
276 ..
277 }) = &thread.entries()[first_tool_call_ix]
278 else {
279 panic!();
280 };
281 });
282
283 thread
284 .update(cx, |thread, cx| {
285 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
286 })
287 .await
288 .unwrap();
289 thread.read_with(cx, |thread, _| {
290 assert!(matches!(
291 &thread.entries().last().unwrap(),
292 AgentThreadEntry::AssistantMessage(..),
293 ))
294 });
295}
296
/// Generates a `common_e2e` module containing one `gpui` test per shared
/// end-to-end scenario defined in `e2e_tests`.
///
/// `$server` is an expression yielding the agent server under test; it is
/// expanded once inside each generated test. Every test carries the `ignore`
/// attribute unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            use super::*;

            /// Runs the `test_basic` scenario.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            /// Runs the `test_path_mentions` scenario.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            /// Runs the `test_tool_call` scenario.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            /// Runs the `test_tool_call_with_confirmation` scenario.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
            }

            /// Runs the `test_cancel` scenario.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
335
336// Helpers
337
338pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
339 env_logger::try_init().ok();
340
341 cx.update(|cx| {
342 let settings_store = SettingsStore::test(cx);
343 cx.set_global(settings_store);
344 Project::init_settings(cx);
345 language::init(cx);
346 crate::settings::init(cx);
347
348 crate::AllAgentServersSettings::override_global(
349 AllAgentServersSettings {
350 claude: Some(AgentServerSettings {
351 command: crate::claude::tests::local_command(),
352 }),
353 gemini: Some(AgentServerSettings {
354 command: crate::gemini::tests::local_command(),
355 }),
356 },
357 cx,
358 );
359 });
360
361 cx.executor().allow_parking();
362
363 FakeFs::new(cx.executor())
364}
365
366pub async fn new_test_thread(
367 server: impl AgentServer + 'static,
368 project: Entity<Project>,
369 current_dir: impl AsRef<Path>,
370 cx: &mut TestAppContext,
371) -> Entity<AcpThread> {
372 let thread = cx
373 .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
374 .await
375 .unwrap();
376
377 thread
378 .update(cx, |thread, _| thread.initialize())
379 .await
380 .unwrap();
381 thread
382}
383
384pub async fn run_until_first_tool_call(
385 thread: &Entity<AcpThread>,
386 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
387 cx: &mut TestAppContext,
388) -> usize {
389 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
390
391 let subscription = cx.update(|cx| {
392 cx.subscribe(thread, move |thread, _, cx| {
393 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
394 if wait_until(entry) {
395 return tx.try_send(ix).unwrap();
396 }
397 }
398 })
399 });
400
401 select! {
402 // We have to use a smol timer here because
403 // cx.background_executor().timer isn't real in the test context
404 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
405 panic!("Timeout waiting for tool call")
406 }
407 ix = rx.next().fuse() => {
408 drop(subscription);
409 ix.unwrap()
410 }
411 }
412}