1use super::*;
2use crate::MessageContent;
3use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use fs::{FakeFs, Fs};
10use futures::channel::mpsc::UnboundedReceiver;
11use gpui::{
12 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
13};
14use indoc::indoc;
15use language_model::{
16 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
17 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
18 fake_provider::FakeLanguageModel,
19};
20use project::Project;
21use prompt_store::ProjectContext;
22use reqwest_client::ReqwestClient;
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26use settings::SettingsStore;
27use smol::stream::StreamExt;
28use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
29use util::path;
30
31mod test_tools;
32use test_tools::*;
33
34#[gpui::test]
35#[ignore = "can't run on CI yet"]
36async fn test_echo(cx: &mut TestAppContext) {
37 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
38
39 let events = thread
40 .update(cx, |thread, cx| {
41 thread.send("Testing: Reply with 'Hello'", cx)
42 })
43 .collect()
44 .await;
45 thread.update(cx, |thread, _cx| {
46 assert_eq!(
47 thread.messages().last().unwrap().content,
48 vec![MessageContent::Text("Hello".to_string())]
49 );
50 });
51 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
52}
53
54#[gpui::test]
55#[ignore = "can't run on CI yet"]
56async fn test_thinking(cx: &mut TestAppContext) {
57 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
58
59 let events = thread
60 .update(cx, |thread, cx| {
61 thread.send(
62 indoc! {"
63 Testing:
64
65 Generate a thinking step where you just think the word 'Think',
66 and have your final answer be 'Hello'
67 "},
68 cx,
69 )
70 })
71 .collect()
72 .await;
73 thread.update(cx, |thread, _cx| {
74 assert_eq!(
75 thread.messages().last().unwrap().to_markdown(),
76 indoc! {"
77 ## assistant
78 <think>Think</think>
79 Hello
80 "}
81 )
82 });
83 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
84}
85
86#[gpui::test]
87async fn test_system_prompt(cx: &mut TestAppContext) {
88 let ThreadTest {
89 model,
90 thread,
91 project_context,
92 ..
93 } = setup(cx, TestModel::Fake).await;
94 let fake_model = model.as_fake();
95
96 project_context.borrow_mut().shell = "test-shell".into();
97 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
98 thread.update(cx, |thread, cx| thread.send("abc", cx));
99 cx.run_until_parked();
100 let mut pending_completions = fake_model.pending_completions();
101 assert_eq!(
102 pending_completions.len(),
103 1,
104 "unexpected pending completions: {:?}",
105 pending_completions
106 );
107
108 let pending_completion = pending_completions.pop().unwrap();
109 assert_eq!(pending_completion.messages[0].role, Role::System);
110
111 let system_message = &pending_completion.messages[0];
112 let system_prompt = system_message.content[0].to_str().unwrap();
113 assert!(
114 system_prompt.contains("test-shell"),
115 "unexpected system message: {:?}",
116 system_message
117 );
118 assert!(
119 system_prompt.contains("## Fixing Diagnostics"),
120 "unexpected system message: {:?}",
121 system_message
122 );
123}
124
125#[gpui::test]
126#[ignore = "can't run on CI yet"]
127async fn test_basic_tool_calls(cx: &mut TestAppContext) {
128 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
129
130 // Test a tool call that's likely to complete *before* streaming stops.
131 let events = thread
132 .update(cx, |thread, cx| {
133 thread.add_tool(EchoTool);
134 thread.send(
135 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
136 cx,
137 )
138 })
139 .collect()
140 .await;
141 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
142
143 // Test a tool calls that's likely to complete *after* streaming stops.
144 let events = thread
145 .update(cx, |thread, cx| {
146 thread.remove_tool(&AgentTool::name(&EchoTool));
147 thread.add_tool(DelayTool);
148 thread.send(
149 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
150 cx,
151 )
152 })
153 .collect()
154 .await;
155 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
156 thread.update(cx, |thread, _cx| {
157 assert!(
158 thread
159 .messages()
160 .last()
161 .unwrap()
162 .content
163 .iter()
164 .any(|content| {
165 if let MessageContent::Text(text) = content {
166 text.contains("Ding")
167 } else {
168 false
169 }
170 }),
171 "{}",
172 thread.to_markdown()
173 );
174 });
175}
176
177#[gpui::test]
178#[ignore = "can't run on CI yet"]
179async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
180 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
181
182 // Test a tool call that's likely to complete *before* streaming stops.
183 let mut events = thread.update(cx, |thread, cx| {
184 thread.add_tool(WordListTool);
185 thread.send("Test the word_list tool.", cx)
186 });
187
188 let mut saw_partial_tool_use = false;
189 while let Some(event) = events.next().await {
190 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
191 thread.update(cx, |thread, _cx| {
192 // Look for a tool use in the thread's last message
193 let last_content = thread.messages().last().unwrap().content.last().unwrap();
194 if let MessageContent::ToolUse(last_tool_use) = last_content {
195 assert_eq!(last_tool_use.name.as_ref(), "word_list");
196 if tool_call.status == acp::ToolCallStatus::Pending {
197 if !last_tool_use.is_input_complete
198 && last_tool_use.input.get("g").is_none()
199 {
200 saw_partial_tool_use = true;
201 }
202 } else {
203 last_tool_use
204 .input
205 .get("a")
206 .expect("'a' has streamed because input is now complete");
207 last_tool_use
208 .input
209 .get("g")
210 .expect("'g' has streamed because input is now complete");
211 }
212 } else {
213 panic!("last content should be a tool use");
214 }
215 });
216 }
217 }
218
219 assert!(
220 saw_partial_tool_use,
221 "should see at least one partially streamed tool use in the history"
222 );
223}
224
225#[gpui::test]
226async fn test_tool_authorization(cx: &mut TestAppContext) {
227 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
228 let fake_model = model.as_fake();
229
230 let mut events = thread.update(cx, |thread, cx| {
231 thread.add_tool(ToolRequiringPermission);
232 thread.send("abc", cx)
233 });
234 cx.run_until_parked();
235 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
236 LanguageModelToolUse {
237 id: "tool_id_1".into(),
238 name: ToolRequiringPermission.name().into(),
239 raw_input: "{}".into(),
240 input: json!({}),
241 is_input_complete: true,
242 },
243 ));
244 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
245 LanguageModelToolUse {
246 id: "tool_id_2".into(),
247 name: ToolRequiringPermission.name().into(),
248 raw_input: "{}".into(),
249 input: json!({}),
250 is_input_complete: true,
251 },
252 ));
253 fake_model.end_last_completion_stream();
254 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
255 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
256
257 // Approve the first
258 tool_call_auth_1
259 .response
260 .send(tool_call_auth_1.options[1].id.clone())
261 .unwrap();
262 cx.run_until_parked();
263
264 // Reject the second
265 tool_call_auth_2
266 .response
267 .send(tool_call_auth_1.options[2].id.clone())
268 .unwrap();
269 cx.run_until_parked();
270
271 let completion = fake_model.pending_completions().pop().unwrap();
272 let message = completion.messages.last().unwrap();
273 assert_eq!(
274 message.content,
275 vec![
276 language_model::MessageContent::ToolResult(LanguageModelToolResult {
277 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
278 tool_name: ToolRequiringPermission.name().into(),
279 is_error: false,
280 content: "Allowed".into(),
281 output: Some("Allowed".into())
282 }),
283 language_model::MessageContent::ToolResult(LanguageModelToolResult {
284 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
285 tool_name: ToolRequiringPermission.name().into(),
286 is_error: true,
287 content: "Permission to run tool denied by user".into(),
288 output: None
289 })
290 ]
291 );
292
293 // Simulate yet another tool call.
294 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
295 LanguageModelToolUse {
296 id: "tool_id_3".into(),
297 name: ToolRequiringPermission.name().into(),
298 raw_input: "{}".into(),
299 input: json!({}),
300 is_input_complete: true,
301 },
302 ));
303 fake_model.end_last_completion_stream();
304
305 // Respond by always allowing tools.
306 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
307 tool_call_auth_3
308 .response
309 .send(tool_call_auth_3.options[0].id.clone())
310 .unwrap();
311 cx.run_until_parked();
312 let completion = fake_model.pending_completions().pop().unwrap();
313 let message = completion.messages.last().unwrap();
314 assert_eq!(
315 message.content,
316 vec![language_model::MessageContent::ToolResult(
317 LanguageModelToolResult {
318 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
319 tool_name: ToolRequiringPermission.name().into(),
320 is_error: false,
321 content: "Allowed".into(),
322 output: Some("Allowed".into())
323 }
324 )]
325 );
326
327 // Simulate a final tool call, ensuring we don't trigger authorization.
328 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
329 LanguageModelToolUse {
330 id: "tool_id_4".into(),
331 name: ToolRequiringPermission.name().into(),
332 raw_input: "{}".into(),
333 input: json!({}),
334 is_input_complete: true,
335 },
336 ));
337 fake_model.end_last_completion_stream();
338 cx.run_until_parked();
339 let completion = fake_model.pending_completions().pop().unwrap();
340 let message = completion.messages.last().unwrap();
341 assert_eq!(
342 message.content,
343 vec![language_model::MessageContent::ToolResult(
344 LanguageModelToolResult {
345 tool_use_id: "tool_id_4".into(),
346 tool_name: ToolRequiringPermission.name().into(),
347 is_error: false,
348 content: "Allowed".into(),
349 output: Some("Allowed".into())
350 }
351 )]
352 );
353}
354
355#[gpui::test]
356async fn test_tool_hallucination(cx: &mut TestAppContext) {
357 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
358 let fake_model = model.as_fake();
359
360 let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
361 cx.run_until_parked();
362 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
363 LanguageModelToolUse {
364 id: "tool_id_1".into(),
365 name: "nonexistent_tool".into(),
366 raw_input: "{}".into(),
367 input: json!({}),
368 is_input_complete: true,
369 },
370 ));
371 fake_model.end_last_completion_stream();
372
373 let tool_call = expect_tool_call(&mut events).await;
374 assert_eq!(tool_call.title, "nonexistent_tool");
375 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
376 let update = expect_tool_call_update_fields(&mut events).await;
377 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
378}
379
380async fn expect_tool_call(
381 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
382) -> acp::ToolCall {
383 let event = events
384 .next()
385 .await
386 .expect("no tool call authorization event received")
387 .unwrap();
388 match event {
389 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
390 event => {
391 panic!("Unexpected event {event:?}");
392 }
393 }
394}
395
396async fn expect_tool_call_update_fields(
397 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
398) -> acp::ToolCallUpdate {
399 let event = events
400 .next()
401 .await
402 .expect("no tool call authorization event received")
403 .unwrap();
404 match event {
405 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
406 return update;
407 }
408 event => {
409 panic!("Unexpected event {event:?}");
410 }
411 }
412}
413
414async fn next_tool_call_authorization(
415 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
416) -> ToolCallAuthorization {
417 loop {
418 let event = events
419 .next()
420 .await
421 .expect("no tool call authorization event received")
422 .unwrap();
423 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
424 let permission_kinds = tool_call_authorization
425 .options
426 .iter()
427 .map(|o| o.kind)
428 .collect::<Vec<_>>();
429 assert_eq!(
430 permission_kinds,
431 vec![
432 acp::PermissionOptionKind::AllowAlways,
433 acp::PermissionOptionKind::AllowOnce,
434 acp::PermissionOptionKind::RejectOnce,
435 ]
436 );
437 return tool_call_authorization;
438 }
439 }
440}
441
442#[gpui::test]
443#[ignore = "can't run on CI yet"]
444async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
445 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
446
447 // Test concurrent tool calls with different delay times
448 let events = thread
449 .update(cx, |thread, cx| {
450 thread.add_tool(DelayTool);
451 thread.send(
452 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
453 cx,
454 )
455 })
456 .collect()
457 .await;
458
459 let stop_reasons = stop_events(events);
460 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
461
462 thread.update(cx, |thread, _cx| {
463 let last_message = thread.messages().last().unwrap();
464 let text = last_message
465 .content
466 .iter()
467 .filter_map(|content| {
468 if let MessageContent::Text(text) = content {
469 Some(text.as_str())
470 } else {
471 None
472 }
473 })
474 .collect::<String>();
475
476 assert!(text.contains("Ding"));
477 });
478}
479
480#[gpui::test]
481async fn test_profiles(cx: &mut TestAppContext) {
482 let ThreadTest {
483 model, thread, fs, ..
484 } = setup(cx, TestModel::Fake).await;
485 let fake_model = model.as_fake();
486
487 thread.update(cx, |thread, _cx| {
488 thread.add_tool(DelayTool);
489 thread.add_tool(EchoTool);
490 thread.add_tool(InfiniteTool);
491 });
492
493 // Override profiles and wait for settings to be loaded.
494 fs.insert_file(
495 paths::settings_file(),
496 json!({
497 "agent": {
498 "profiles": {
499 "test-1": {
500 "name": "Test Profile 1",
501 "tools": {
502 EchoTool.name(): true,
503 DelayTool.name(): true,
504 }
505 },
506 "test-2": {
507 "name": "Test Profile 2",
508 "tools": {
509 InfiniteTool.name(): true,
510 }
511 }
512 }
513 }
514 })
515 .to_string()
516 .into_bytes(),
517 )
518 .await;
519 cx.run_until_parked();
520
521 // Test that test-1 profile (default) has echo and delay tools
522 thread.update(cx, |thread, cx| {
523 thread.set_profile(AgentProfileId("test-1".into()));
524 thread.send("test", cx);
525 });
526 cx.run_until_parked();
527
528 let mut pending_completions = fake_model.pending_completions();
529 assert_eq!(pending_completions.len(), 1);
530 let completion = pending_completions.pop().unwrap();
531 let tool_names: Vec<String> = completion
532 .tools
533 .iter()
534 .map(|tool| tool.name.clone())
535 .collect();
536 assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
537 fake_model.end_last_completion_stream();
538
539 // Switch to test-2 profile, and verify that it has only the infinite tool.
540 thread.update(cx, |thread, cx| {
541 thread.set_profile(AgentProfileId("test-2".into()));
542 thread.send("test2", cx)
543 });
544 cx.run_until_parked();
545 let mut pending_completions = fake_model.pending_completions();
546 assert_eq!(pending_completions.len(), 1);
547 let completion = pending_completions.pop().unwrap();
548 let tool_names: Vec<String> = completion
549 .tools
550 .iter()
551 .map(|tool| tool.name.clone())
552 .collect();
553 assert_eq!(tool_names, vec![InfiniteTool.name()]);
554}
555
556#[gpui::test]
557#[ignore = "can't run on CI yet"]
558async fn test_cancellation(cx: &mut TestAppContext) {
559 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
560
561 let mut events = thread.update(cx, |thread, cx| {
562 thread.add_tool(InfiniteTool);
563 thread.add_tool(EchoTool);
564 thread.send(
565 "Call the echo tool and then call the infinite tool, then explain their output",
566 cx,
567 )
568 });
569
570 // Wait until both tools are called.
571 let mut expected_tools = vec!["Echo", "Infinite Tool"];
572 let mut echo_id = None;
573 let mut echo_completed = false;
574 while let Some(event) = events.next().await {
575 match event.unwrap() {
576 AgentResponseEvent::ToolCall(tool_call) => {
577 assert_eq!(tool_call.title, expected_tools.remove(0));
578 if tool_call.title == "Echo" {
579 echo_id = Some(tool_call.id);
580 }
581 }
582 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
583 acp::ToolCallUpdate {
584 id,
585 fields:
586 acp::ToolCallUpdateFields {
587 status: Some(acp::ToolCallStatus::Completed),
588 ..
589 },
590 },
591 )) if Some(&id) == echo_id.as_ref() => {
592 echo_completed = true;
593 }
594 _ => {}
595 }
596
597 if expected_tools.is_empty() && echo_completed {
598 break;
599 }
600 }
601
602 // Cancel the current send and ensure that the event stream is closed, even
603 // if one of the tools is still running.
604 thread.update(cx, |thread, _cx| thread.cancel());
605 events.collect::<Vec<_>>().await;
606
607 // Ensure we can still send a new message after cancellation.
608 let events = thread
609 .update(cx, |thread, cx| {
610 thread.send("Testing: reply with 'Hello' then stop.", cx)
611 })
612 .collect::<Vec<_>>()
613 .await;
614 thread.update(cx, |thread, _cx| {
615 assert_eq!(
616 thread.messages().last().unwrap().content,
617 vec![MessageContent::Text("Hello".to_string())]
618 );
619 });
620 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
621}
622
623#[gpui::test]
624async fn test_refusal(cx: &mut TestAppContext) {
625 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
626 let fake_model = model.as_fake();
627
628 let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
629 cx.run_until_parked();
630 thread.read_with(cx, |thread, _| {
631 assert_eq!(
632 thread.to_markdown(),
633 indoc! {"
634 ## user
635 Hello
636 "}
637 );
638 });
639
640 fake_model.send_last_completion_stream_text_chunk("Hey!");
641 cx.run_until_parked();
642 thread.read_with(cx, |thread, _| {
643 assert_eq!(
644 thread.to_markdown(),
645 indoc! {"
646 ## user
647 Hello
648 ## assistant
649 Hey!
650 "}
651 );
652 });
653
654 // If the model refuses to continue, the thread should remove all the messages after the last user message.
655 fake_model
656 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
657 let events = events.collect::<Vec<_>>().await;
658 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
659 thread.read_with(cx, |thread, _| {
660 assert_eq!(thread.to_markdown(), "");
661 });
662}
663
// End-to-end test of `NativeAgentConnection`: lists models through the model
// selector, creates a thread, checks the selected model, sends a prompt,
// cancels it, and finally verifies that prompting a session whose ACP thread
// was dropped fails with "Session not found".
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // Unlike `setup`, this test uses a fake HTTP client and a fake clock.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models: the list must be grouped, non-empty, and the first
    // model of the "Fake" group must have the id "fake/fake".
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected model's id back to the underlying model so the test
    // can drive it as a fake.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt, stream one text chunk from the fake model, and check the
    // ACP thread's markdown rendering.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    // NOTE(review): despite the wording of the message, `expect` requires the
    // cancelled request to resolve to `Ok` — confirm that is the intent.
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
792
// Streams a thinking-tool call in two steps (partial input, then complete
// input) through the fake model and asserts the exact sequence of events the
// thread emits: the initial tool call followed by four field updates.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            // NOTE(review): presumably equal to `ThinkingTool.name()` used for
            // the partial event above — confirm and consider unifying.
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1. The tool call is announced as pending, carrying the partial (empty)
    //    input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2. An update carries the title, kind and the complete input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3. The call transitions to in-progress.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4. The tool reports its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5. The call completes and carries the raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
890
891/// Filters out the stop events for asserting against in tests
892fn stop_events(
893 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
894) -> Vec<acp::StopReason> {
895 result_events
896 .into_iter()
897 .filter_map(|event| match event.unwrap() {
898 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
899 _ => None,
900 })
901 .collect()
902}
903
/// Bundle of handles produced by `setup` for a single test.
struct ThreadTest {
    /// The language model the thread was created with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread; tests may mutate it before
    /// sending a message.
    project_context: Rc<RefCell<ProjectContext>>,
    /// The fake file system that holds the settings file and the test project.
    fs: Arc<FakeFs>,
}
910
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// The model registered under the id "claude-sonnet-4-latest".
    Sonnet4,
    /// The model registered under the id "claude-sonnet-4-thinking-latest".
    Sonnet4Thinking,
    /// A `FakeLanguageModel` that the test drives by hand.
    Fake,
}
916
917impl TestModel {
918 fn id(&self) -> LanguageModelId {
919 match self {
920 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
921 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
922 TestModel::Fake => unreachable!(),
923 }
924 }
925}
926
/// Builds the shared fixture for the thread tests.
///
/// Writes a settings file with a default "test-profile" that enables the test
/// tools, initialises the global settings/client/language-model state, creates
/// a test project rooted at `/test`, resolves the requested `model`, and
/// finally constructs the `Thread` under test.
///
/// # Panics
///
/// Panics if any of the unwrapped setup steps fails, e.g. when a non-fake
/// model cannot be found among the available models or its provider fails to
/// authenticate.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed the fake file system with a settings file before anything watches it.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialise global state. NOTE(review): the order of these init calls is
    // presumably significant — keep it unless verified otherwise.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        // Keep the settings store in sync with later edits of the settings file.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Look the model up by id among the registry's available models.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Hand the model out only after authentication has succeeded.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            model.clone(),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1025
/// Initialises `env_logger` at load time, but only when `RUST_LOG` is set to a
/// valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
1033
1034fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1035 let fs = fs.clone();
1036 cx.spawn({
1037 async move |cx| {
1038 let mut new_settings_content_rx = settings::watch_config_file(
1039 cx.background_executor(),
1040 fs,
1041 paths::settings_file().clone(),
1042 );
1043
1044 while let Some(new_settings_content) = new_settings_content_rx.next().await {
1045 cx.update(|cx| {
1046 SettingsStore::update_global(cx, |settings, cx| {
1047 settings.set_user_settings(&new_settings_content, cx)
1048 })
1049 })
1050 .ok();
1051 }
1052 }
1053 })
1054 .detach();
1055}