1use super::*;
2use crate::templates::Templates;
3use acp_thread::AgentConnection;
4use agent_client_protocol::{self as acp};
5use anyhow::Result;
6use assistant_tool::ActionLog;
7use client::{Client, UserStore};
8use fs::FakeFs;
9use futures::channel::mpsc::UnboundedReceiver;
10use gpui::{http_client::FakeHttpClient, AppContext, Entity, Task, TestAppContext};
11use indoc::indoc;
12use language_model::{
13 fake_provider::FakeLanguageModel, LanguageModel, LanguageModelCompletionError,
14 LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry, LanguageModelToolResult,
15 LanguageModelToolUse, MessageContent, Role, StopReason,
16};
17use project::Project;
18use prompt_store::ProjectContext;
19use reqwest_client::ReqwestClient;
20use schemars::JsonSchema;
21use serde::{Deserialize, Serialize};
22use serde_json::json;
23use smol::stream::StreamExt;
24use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
25use util::path;
26
27mod test_tools;
28use test_tools::*;
29
30#[gpui::test]
31#[ignore = "can't run on CI yet"]
32async fn test_echo(cx: &mut TestAppContext) {
33 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
34
35 let events = thread
36 .update(cx, |thread, cx| {
37 thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
38 })
39 .collect()
40 .await;
41 thread.update(cx, |thread, _cx| {
42 assert_eq!(
43 thread.messages().last().unwrap().content,
44 vec![MessageContent::Text("Hello".to_string())]
45 );
46 });
47 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
48}
49
50#[gpui::test]
51#[ignore = "can't run on CI yet"]
52async fn test_thinking(cx: &mut TestAppContext) {
53 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
54
55 let events = thread
56 .update(cx, |thread, cx| {
57 thread.send(
58 model.clone(),
59 indoc! {"
60 Testing:
61
62 Generate a thinking step where you just think the word 'Think',
63 and have your final answer be 'Hello'
64 "},
65 cx,
66 )
67 })
68 .collect()
69 .await;
70 thread.update(cx, |thread, _cx| {
71 assert_eq!(
72 thread.messages().last().unwrap().to_markdown(),
73 indoc! {"
74 ## assistant
75 <think>Think</think>
76 Hello
77 "}
78 )
79 });
80 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
81}
82
83#[gpui::test]
84async fn test_system_prompt(cx: &mut TestAppContext) {
85 let ThreadTest {
86 model,
87 thread,
88 project_context,
89 ..
90 } = setup(cx, TestModel::Fake).await;
91 let fake_model = model.as_fake();
92
93 project_context.borrow_mut().shell = "test-shell".into();
94 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
95 thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
96 cx.run_until_parked();
97 let mut pending_completions = fake_model.pending_completions();
98 assert_eq!(
99 pending_completions.len(),
100 1,
101 "unexpected pending completions: {:?}",
102 pending_completions
103 );
104
105 let pending_completion = pending_completions.pop().unwrap();
106 assert_eq!(pending_completion.messages[0].role, Role::System);
107
108 let system_message = &pending_completion.messages[0];
109 let system_prompt = system_message.content[0].to_str().unwrap();
110 assert!(
111 system_prompt.contains("test-shell"),
112 "unexpected system message: {:?}",
113 system_message
114 );
115 assert!(
116 system_prompt.contains("## Fixing Diagnostics"),
117 "unexpected system message: {:?}",
118 system_message
119 );
120}
121
122#[gpui::test]
123#[ignore = "can't run on CI yet"]
124async fn test_basic_tool_calls(cx: &mut TestAppContext) {
125 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
126
127 // Test a tool call that's likely to complete *before* streaming stops.
128 let events = thread
129 .update(cx, |thread, cx| {
130 thread.add_tool(EchoTool);
131 thread.send(
132 model.clone(),
133 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
134 cx,
135 )
136 })
137 .collect()
138 .await;
139 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
140
141 // Test a tool calls that's likely to complete *after* streaming stops.
142 let events = thread
143 .update(cx, |thread, cx| {
144 thread.remove_tool(&AgentTool::name(&EchoTool));
145 thread.add_tool(DelayTool);
146 thread.send(
147 model.clone(),
148 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
149 cx,
150 )
151 })
152 .collect()
153 .await;
154 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
155 thread.update(cx, |thread, _cx| {
156 assert!(thread
157 .messages()
158 .last()
159 .unwrap()
160 .content
161 .iter()
162 .any(|content| {
163 if let MessageContent::Text(text) = content {
164 text.contains("Ding")
165 } else {
166 false
167 }
168 }));
169 });
170}
171
172#[gpui::test]
173#[ignore = "can't run on CI yet"]
174async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
175 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
176
177 // Test a tool call that's likely to complete *before* streaming stops.
178 let mut events = thread.update(cx, |thread, cx| {
179 thread.add_tool(WordListTool);
180 thread.send(model.clone(), "Test the word_list tool.", cx)
181 });
182
183 let mut saw_partial_tool_use = false;
184 while let Some(event) = events.next().await {
185 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
186 thread.update(cx, |thread, _cx| {
187 // Look for a tool use in the thread's last message
188 let last_content = thread.messages().last().unwrap().content.last().unwrap();
189 if let MessageContent::ToolUse(last_tool_use) = last_content {
190 assert_eq!(last_tool_use.name.as_ref(), "word_list");
191 if tool_call.status == acp::ToolCallStatus::Pending {
192 if !last_tool_use.is_input_complete
193 && last_tool_use.input.get("g").is_none()
194 {
195 saw_partial_tool_use = true;
196 }
197 } else {
198 last_tool_use
199 .input
200 .get("a")
201 .expect("'a' has streamed because input is now complete");
202 last_tool_use
203 .input
204 .get("g")
205 .expect("'g' has streamed because input is now complete");
206 }
207 } else {
208 panic!("last content should be a tool use");
209 }
210 });
211 }
212 }
213
214 assert!(
215 saw_partial_tool_use,
216 "should see at least one partially streamed tool use in the history"
217 );
218}
219
220#[gpui::test]
221async fn test_tool_authorization(cx: &mut TestAppContext) {
222 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
223 let fake_model = model.as_fake();
224
225 let mut events = thread.update(cx, |thread, cx| {
226 thread.add_tool(ToolRequiringPermission);
227 thread.send(model.clone(), "abc", cx)
228 });
229 cx.run_until_parked();
230 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
231 LanguageModelToolUse {
232 id: "tool_id_1".into(),
233 name: ToolRequiringPermission.name().into(),
234 raw_input: "{}".into(),
235 input: json!({}),
236 is_input_complete: true,
237 },
238 ));
239 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
240 LanguageModelToolUse {
241 id: "tool_id_2".into(),
242 name: ToolRequiringPermission.name().into(),
243 raw_input: "{}".into(),
244 input: json!({}),
245 is_input_complete: true,
246 },
247 ));
248 fake_model.end_last_completion_stream();
249 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
250 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
251
252 // Approve the first
253 tool_call_auth_1
254 .response
255 .send(tool_call_auth_1.options[1].id.clone())
256 .unwrap();
257 cx.run_until_parked();
258
259 // Reject the second
260 tool_call_auth_2
261 .response
262 .send(tool_call_auth_1.options[2].id.clone())
263 .unwrap();
264 cx.run_until_parked();
265
266 let completion = fake_model.pending_completions().pop().unwrap();
267 let message = completion.messages.last().unwrap();
268 assert_eq!(
269 message.content,
270 vec![
271 MessageContent::ToolResult(LanguageModelToolResult {
272 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
273 tool_name: tool_call_auth_1.tool_call.title.into(),
274 is_error: false,
275 content: "Allowed".into(),
276 output: None
277 }),
278 MessageContent::ToolResult(LanguageModelToolResult {
279 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
280 tool_name: tool_call_auth_2.tool_call.title.into(),
281 is_error: true,
282 content: "Permission to run tool denied by user".into(),
283 output: None
284 })
285 ]
286 );
287}
288
289async fn next_tool_call_authorization(
290 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
291) -> ToolCallAuthorization {
292 loop {
293 let event = events
294 .next()
295 .await
296 .expect("no tool call authorization event received")
297 .unwrap();
298 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
299 let permission_kinds = tool_call_authorization
300 .options
301 .iter()
302 .map(|o| o.kind)
303 .collect::<Vec<_>>();
304 assert_eq!(
305 permission_kinds,
306 vec![
307 acp::PermissionOptionKind::AllowAlways,
308 acp::PermissionOptionKind::AllowOnce,
309 acp::PermissionOptionKind::RejectOnce,
310 ]
311 );
312 return tool_call_authorization;
313 }
314 }
315}
316
317#[gpui::test]
318#[ignore = "can't run on CI yet"]
319async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
320 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
321
322 // Test concurrent tool calls with different delay times
323 let events = thread
324 .update(cx, |thread, cx| {
325 thread.add_tool(DelayTool);
326 thread.send(
327 model.clone(),
328 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
329 cx,
330 )
331 })
332 .collect()
333 .await;
334
335 let stop_reasons = stop_events(events);
336 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
337
338 thread.update(cx, |thread, _cx| {
339 let last_message = thread.messages().last().unwrap();
340 let text = last_message
341 .content
342 .iter()
343 .filter_map(|content| {
344 if let MessageContent::Text(text) = content {
345 Some(text.as_str())
346 } else {
347 None
348 }
349 })
350 .collect::<String>();
351
352 assert!(text.contains("Ding"));
353 });
354}
355
/// Checks that `Thread::cancel` closes the in-flight event stream even while a
/// tool (`InfiniteTool`) is still running, and that the thread can still be
/// used for a new turn afterwards.
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(InfiniteTool);
        thread.add_tool(EchoTool);
        thread.send(
            model.clone(),
            "Call the echo tool and then call the infinite tool, then explain their output",
            cx,
        )
    });

    // Wait until both tools have been called and the echo call has completed.
    // Tool calls must arrive in exactly this order. Note that `remove(0)`
    // panics if more tool calls arrive than are listed here.
    let mut expected_tool_calls = vec!["echo", "infinite"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            AgentResponseEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tool_calls.remove(0));
                if tool_call.title == "echo" {
                    // Remember the echo call's ID so its completion update can be matched below.
                    echo_id = Some(tool_call.id);
                }
            }
            // Only a `Completed` status update for the echo call counts.
            AgentResponseEvent::ToolCallUpdate(acp::ToolCallUpdate {
                id,
                fields:
                    acp::ToolCallUpdateFields {
                        status: Some(acp::ToolCallStatus::Completed),
                        ..
                    },
            }) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tool_calls.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running. Collecting the remaining events
    // only finishes once the stream has ended.
    thread.update(cx, |thread, _cx| thread.cancel());
    events.collect::<Vec<_>>().await;

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
        })
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.messages().last().unwrap().content,
            vec![MessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
421
422#[gpui::test]
423async fn test_refusal(cx: &mut TestAppContext) {
424 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
425 let fake_model = model.as_fake();
426
427 let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
428 cx.run_until_parked();
429 thread.read_with(cx, |thread, _| {
430 assert_eq!(
431 thread.to_markdown(),
432 indoc! {"
433 ## user
434 Hello
435 "}
436 );
437 });
438
439 fake_model.send_last_completion_stream_text_chunk("Hey!");
440 cx.run_until_parked();
441 thread.read_with(cx, |thread, _| {
442 assert_eq!(
443 thread.to_markdown(),
444 indoc! {"
445 ## user
446 Hello
447 ## assistant
448 Hey!
449 "}
450 );
451 });
452
453 // If the model refuses to continue, the thread should remove all the messages after the last user message.
454 fake_model
455 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
456 let events = events.collect::<Vec<_>>().await;
457 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
458 thread.read_with(cx, |thread, _| {
459 assert_eq!(thread.to_markdown(), "");
460 });
461}
462
463#[gpui::test]
464async fn test_agent_connection(cx: &mut TestAppContext) {
465 cx.update(settings::init);
466 let templates = Templates::new();
467
468 // Initialize language model system with test provider
469 cx.update(|cx| {
470 gpui_tokio::init(cx);
471 client::init_settings(cx);
472
473 let http_client = FakeHttpClient::with_404_response();
474 let clock = Arc::new(clock::FakeSystemClock::new());
475 let client = Client::new(clock, http_client, cx);
476 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
477 language_model::init(client.clone(), cx);
478 language_models::init(user_store.clone(), client.clone(), cx);
479 Project::init_settings(cx);
480 LanguageModelRegistry::test(cx);
481 });
482 cx.executor().forbid_parking();
483
484 // Create a project for new_thread
485 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
486 fake_fs.insert_tree(path!("/test"), json!({})).await;
487 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
488 let cwd = Path::new("/test");
489
490 // Create agent and connection
491 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
492 .await
493 .unwrap();
494 let connection = NativeAgentConnection(agent.clone());
495
496 // Test model_selector returns Some
497 let selector_opt = connection.model_selector();
498 assert!(
499 selector_opt.is_some(),
500 "agent2 should always support ModelSelector"
501 );
502 let selector = selector_opt.unwrap();
503
504 // Test list_models
505 let listed_models = cx
506 .update(|cx| {
507 let mut async_cx = cx.to_async();
508 selector.list_models(&mut async_cx)
509 })
510 .await
511 .expect("list_models should succeed");
512 assert!(!listed_models.is_empty(), "should have at least one model");
513 assert_eq!(listed_models[0].id().0, "fake");
514
515 // Create a thread using new_thread
516 let connection_rc = Rc::new(connection.clone());
517 let acp_thread = cx
518 .update(|cx| {
519 let mut async_cx = cx.to_async();
520 connection_rc.new_thread(project, cwd, &mut async_cx)
521 })
522 .await
523 .expect("new_thread should succeed");
524
525 // Get the session_id from the AcpThread
526 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
527
528 // Test selected_model returns the default
529 let model = cx
530 .update(|cx| {
531 let mut async_cx = cx.to_async();
532 selector.selected_model(&session_id, &mut async_cx)
533 })
534 .await
535 .expect("selected_model should succeed");
536 let model = model.as_fake();
537 assert_eq!(model.id().0, "fake", "should return default model");
538
539 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
540 cx.run_until_parked();
541 model.send_last_completion_stream_text_chunk("def");
542 cx.run_until_parked();
543 acp_thread.read_with(cx, |thread, cx| {
544 assert_eq!(
545 thread.to_markdown(cx),
546 indoc! {"
547 ## User
548
549 abc
550
551 ## Assistant
552
553 def
554
555 "}
556 )
557 });
558
559 // Test cancel
560 cx.update(|cx| connection.cancel(&session_id, cx));
561 request.await.expect("prompt should fail gracefully");
562
563 // Ensure that dropping the ACP thread causes the native thread to be
564 // dropped as well.
565 cx.update(|_| drop(acp_thread));
566 let result = cx
567 .update(|cx| {
568 connection.prompt(
569 acp::PromptRequest {
570 session_id: session_id.clone(),
571 prompt: vec!["ghi".into()],
572 },
573 cx,
574 )
575 })
576 .await;
577 assert_eq!(
578 result.as_ref().unwrap_err().to_string(),
579 "Session not found",
580 "unexpected result: {:?}",
581 result
582 );
583}
584
585/// Filters out the stop events for asserting against in tests
586fn stop_events(
587 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
588) -> Vec<acp::StopReason> {
589 result_events
590 .into_iter()
591 .filter_map(|event| match event.unwrap() {
592 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
593 _ => None,
594 })
595 .collect()
596}
597
/// Handles produced by `setup` for driving a `Thread` in a test.
struct ThreadTest {
    /// The language model the thread was created with (the fake model for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context shared with the thread. Tests may mutate it (e.g. `shell`) before sending.
    project_context: Rc<RefCell<ProjectContext>>,
}
603
/// Which language model `setup` should build the test thread with.
enum TestModel {
    /// Claude Sonnet 4, looked up in the real model registry and authenticated
    /// (presumably needs network access and provider credentials).
    Sonnet4,
    /// The "thinking" variant of Claude Sonnet 4, looked up the same way as `Sonnet4`.
    Sonnet4Thinking,
    /// An in-process `FakeLanguageModel`; the test drives its output directly.
    Fake,
}
609
610impl TestModel {
611 fn id(&self) -> LanguageModelId {
612 match self {
613 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
614 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
615 TestModel::Fake => unreachable!(),
616 }
617 }
618}
619
/// Builds a `Thread` over a fake-filesystem project rooted at `path!("/test")`,
/// backed by the requested `TestModel`.
///
/// For `TestModel::Fake`, a `FakeLanguageModel` is used directly. For the other
/// variants, the model is found in the global `LanguageModelRegistry` by its
/// `id()`, and its provider's `authenticate` future is awaited before returning.
///
/// # Panics
///
/// Panics if a non-fake model is not in the registry, if its provider cannot be
/// found, or if authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): parking is allowed presumably because the real-model path
    // performs actual network I/O through the Reqwest client — confirm.
    cx.executor().allow_parking();
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
    });
    let templates = Templates::new();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

    // Resolve the language model. The closure returns a `Task` in both
    // branches, so the fake and real paths share one `.await`.
    let model = cx
        .update(|cx| {
            gpui_tokio::init(cx);
            let http_client = ReqwestClient::user_agent("agent tests").unwrap();
            cx.set_http_client(Arc::new(http_client));

            client::init_settings(cx);
            let client = Client::production(cx);
            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            language_model::init(client.clone(), cx);
            language_models::init(user_store.clone(), client.clone(), cx);

            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Hand back the model only after authentication has succeeded.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // The project context is shared (via `Rc`) so tests can mutate it after the thread exists.
    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|_| {
        Thread::new(
            project,
            project_context.clone(),
            action_log,
            templates,
            model.clone(),
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
    }
}
682
/// Runs before the tests (via `ctor`) and turns on `env_logger` output, but
/// only when `RUST_LOG` is set to a valid value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}