1use super::*;
2use crate::MessageContent;
3use acp_thread::AgentConnection;
4use action_log::ActionLog;
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use fs::{FakeFs, Fs};
10use futures::channel::mpsc::UnboundedReceiver;
11use gpui::{
12 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
13};
14use indoc::indoc;
15use language_model::{
16 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
17 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
18 fake_provider::FakeLanguageModel,
19};
20use project::Project;
21use prompt_store::ProjectContext;
22use reqwest_client::ReqwestClient;
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::json;
26use settings::SettingsStore;
27use smol::stream::StreamExt;
28use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
29use util::path;
30
31mod test_tools;
32use test_tools::*;
33
34#[gpui::test]
35#[ignore = "can't run on CI yet"]
36async fn test_echo(cx: &mut TestAppContext) {
37 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
38
39 let events = thread
40 .update(cx, |thread, cx| {
41 thread.send("Testing: Reply with 'Hello'", cx)
42 })
43 .collect()
44 .await;
45 thread.update(cx, |thread, _cx| {
46 assert_eq!(
47 thread.messages().last().unwrap().content,
48 vec![MessageContent::Text("Hello".to_string())]
49 );
50 });
51 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
52}
53
54#[gpui::test]
55#[ignore = "can't run on CI yet"]
56async fn test_thinking(cx: &mut TestAppContext) {
57 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
58
59 let events = thread
60 .update(cx, |thread, cx| {
61 thread.send(
62 indoc! {"
63 Testing:
64
65 Generate a thinking step where you just think the word 'Think',
66 and have your final answer be 'Hello'
67 "},
68 cx,
69 )
70 })
71 .collect()
72 .await;
73 thread.update(cx, |thread, _cx| {
74 assert_eq!(
75 thread.messages().last().unwrap().to_markdown(),
76 indoc! {"
77 ## assistant
78 <think>Think</think>
79 Hello
80 "}
81 )
82 });
83 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
84}
85
86#[gpui::test]
87async fn test_system_prompt(cx: &mut TestAppContext) {
88 let ThreadTest {
89 model,
90 thread,
91 project_context,
92 ..
93 } = setup(cx, TestModel::Fake).await;
94 let fake_model = model.as_fake();
95
96 project_context.borrow_mut().shell = "test-shell".into();
97 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
98 thread.update(cx, |thread, cx| thread.send("abc", cx));
99 cx.run_until_parked();
100 let mut pending_completions = fake_model.pending_completions();
101 assert_eq!(
102 pending_completions.len(),
103 1,
104 "unexpected pending completions: {:?}",
105 pending_completions
106 );
107
108 let pending_completion = pending_completions.pop().unwrap();
109 assert_eq!(pending_completion.messages[0].role, Role::System);
110
111 let system_message = &pending_completion.messages[0];
112 let system_prompt = system_message.content[0].to_str().unwrap();
113 assert!(
114 system_prompt.contains("test-shell"),
115 "unexpected system message: {:?}",
116 system_message
117 );
118 assert!(
119 system_prompt.contains("## Fixing Diagnostics"),
120 "unexpected system message: {:?}",
121 system_message
122 );
123}
124
125#[gpui::test]
126#[ignore = "can't run on CI yet"]
127async fn test_basic_tool_calls(cx: &mut TestAppContext) {
128 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
129
130 // Test a tool call that's likely to complete *before* streaming stops.
131 let events = thread
132 .update(cx, |thread, cx| {
133 thread.add_tool(EchoTool);
134 thread.send(
135 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
136 cx,
137 )
138 })
139 .collect()
140 .await;
141 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
142
143 // Test a tool calls that's likely to complete *after* streaming stops.
144 let events = thread
145 .update(cx, |thread, cx| {
146 thread.remove_tool(&AgentTool::name(&EchoTool));
147 thread.add_tool(DelayTool);
148 thread.send(
149 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
150 cx,
151 )
152 })
153 .collect()
154 .await;
155 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
156 thread.update(cx, |thread, _cx| {
157 assert!(
158 thread
159 .messages()
160 .last()
161 .unwrap()
162 .content
163 .iter()
164 .any(|content| {
165 if let MessageContent::Text(text) = content {
166 text.contains("Ding")
167 } else {
168 false
169 }
170 }),
171 "{}",
172 thread.to_markdown()
173 );
174 });
175}
176
/// Checks that a tool use is visible in the thread while its input is still streaming.
///
/// At least one `ToolCall` event must be observed while the thread's last tool use is marked
/// incomplete and has no `"g"` key yet. Once the tool call is no longer pending, both `"a"`
/// and `"g"` must be present. Runs against a real model and is ignored on CI.
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to call the word_list tool. Its input presumably has keys "a" through "g"
    // (only "a" and "g" are checked below) — TODO confirm against `test_tools`.
    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(WordListTool);
        thread.send("Test the word_list tool.", cx)
    });

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        // Only successful `ToolCall` events are inspected; errors and other events are skipped.
        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let last_content = thread.messages().last().unwrap().content.last().unwrap();
                if let MessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Input still streaming and the last key not yet seen: a partial use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
224
225#[gpui::test]
226async fn test_tool_authorization(cx: &mut TestAppContext) {
227 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
228 let fake_model = model.as_fake();
229
230 let mut events = thread.update(cx, |thread, cx| {
231 thread.add_tool(ToolRequiringPermission);
232 thread.send("abc", cx)
233 });
234 cx.run_until_parked();
235 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
236 LanguageModelToolUse {
237 id: "tool_id_1".into(),
238 name: ToolRequiringPermission.name().into(),
239 raw_input: "{}".into(),
240 input: json!({}),
241 is_input_complete: true,
242 },
243 ));
244 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
245 LanguageModelToolUse {
246 id: "tool_id_2".into(),
247 name: ToolRequiringPermission.name().into(),
248 raw_input: "{}".into(),
249 input: json!({}),
250 is_input_complete: true,
251 },
252 ));
253 fake_model.end_last_completion_stream();
254 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
255 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
256
257 // Approve the first
258 tool_call_auth_1
259 .response
260 .send(tool_call_auth_1.options[1].id.clone())
261 .unwrap();
262 cx.run_until_parked();
263
264 // Reject the second
265 tool_call_auth_2
266 .response
267 .send(tool_call_auth_1.options[2].id.clone())
268 .unwrap();
269 cx.run_until_parked();
270
271 let completion = fake_model.pending_completions().pop().unwrap();
272 let message = completion.messages.last().unwrap();
273 assert_eq!(
274 message.content,
275 vec![
276 language_model::MessageContent::ToolResult(LanguageModelToolResult {
277 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
278 tool_name: ToolRequiringPermission.name().into(),
279 is_error: false,
280 content: "Allowed".into(),
281 output: Some("Allowed".into())
282 }),
283 language_model::MessageContent::ToolResult(LanguageModelToolResult {
284 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
285 tool_name: ToolRequiringPermission.name().into(),
286 is_error: true,
287 content: "Permission to run tool denied by user".into(),
288 output: None
289 })
290 ]
291 );
292
293 // Simulate yet another tool call.
294 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
295 LanguageModelToolUse {
296 id: "tool_id_3".into(),
297 name: ToolRequiringPermission.name().into(),
298 raw_input: "{}".into(),
299 input: json!({}),
300 is_input_complete: true,
301 },
302 ));
303 fake_model.end_last_completion_stream();
304
305 // Respond by always allowing tools.
306 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
307 tool_call_auth_3
308 .response
309 .send(tool_call_auth_3.options[0].id.clone())
310 .unwrap();
311 cx.run_until_parked();
312 let completion = fake_model.pending_completions().pop().unwrap();
313 let message = completion.messages.last().unwrap();
314 assert_eq!(
315 message.content,
316 vec![language_model::MessageContent::ToolResult(
317 LanguageModelToolResult {
318 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
319 tool_name: ToolRequiringPermission.name().into(),
320 is_error: false,
321 content: "Allowed".into(),
322 output: Some("Allowed".into())
323 }
324 )]
325 );
326
327 // Simulate a final tool call, ensuring we don't trigger authorization.
328 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
329 LanguageModelToolUse {
330 id: "tool_id_4".into(),
331 name: ToolRequiringPermission.name().into(),
332 raw_input: "{}".into(),
333 input: json!({}),
334 is_input_complete: true,
335 },
336 ));
337 fake_model.end_last_completion_stream();
338 cx.run_until_parked();
339 let completion = fake_model.pending_completions().pop().unwrap();
340 let message = completion.messages.last().unwrap();
341 assert_eq!(
342 message.content,
343 vec![language_model::MessageContent::ToolResult(
344 LanguageModelToolResult {
345 tool_use_id: "tool_id_4".into(),
346 tool_name: ToolRequiringPermission.name().into(),
347 is_error: false,
348 content: "Allowed".into(),
349 output: Some("Allowed".into())
350 }
351 )]
352 );
353}
354
355#[gpui::test]
356async fn test_tool_hallucination(cx: &mut TestAppContext) {
357 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
358 let fake_model = model.as_fake();
359
360 let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
361 cx.run_until_parked();
362 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
363 LanguageModelToolUse {
364 id: "tool_id_1".into(),
365 name: "nonexistent_tool".into(),
366 raw_input: "{}".into(),
367 input: json!({}),
368 is_input_complete: true,
369 },
370 ));
371 fake_model.end_last_completion_stream();
372
373 let tool_call = expect_tool_call(&mut events).await;
374 assert_eq!(tool_call.title, "nonexistent_tool");
375 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
376 let update = expect_tool_call_update_fields(&mut events).await;
377 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
378}
379
380async fn expect_tool_call(
381 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
382) -> acp::ToolCall {
383 let event = events
384 .next()
385 .await
386 .expect("no tool call authorization event received")
387 .unwrap();
388 match event {
389 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
390 event => {
391 panic!("Unexpected event {event:?}");
392 }
393 }
394}
395
396async fn expect_tool_call_update_fields(
397 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
398) -> acp::ToolCallUpdate {
399 let event = events
400 .next()
401 .await
402 .expect("no tool call authorization event received")
403 .unwrap();
404 match event {
405 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
406 return update;
407 }
408 event => {
409 panic!("Unexpected event {event:?}");
410 }
411 }
412}
413
414async fn next_tool_call_authorization(
415 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
416) -> ToolCallAuthorization {
417 loop {
418 let event = events
419 .next()
420 .await
421 .expect("no tool call authorization event received")
422 .unwrap();
423 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
424 let permission_kinds = tool_call_authorization
425 .options
426 .iter()
427 .map(|o| o.kind)
428 .collect::<Vec<_>>();
429 assert_eq!(
430 permission_kinds,
431 vec![
432 acp::PermissionOptionKind::AllowAlways,
433 acp::PermissionOptionKind::AllowOnce,
434 acp::PermissionOptionKind::RejectOnce,
435 ]
436 );
437 return tool_call_authorization;
438 }
439 }
440}
441
442#[gpui::test]
443#[ignore = "can't run on CI yet"]
444async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
445 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
446
447 // Test concurrent tool calls with different delay times
448 let events = thread
449 .update(cx, |thread, cx| {
450 thread.add_tool(DelayTool);
451 thread.send(
452 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
453 cx,
454 )
455 })
456 .collect()
457 .await;
458
459 let stop_reasons = stop_events(events);
460 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
461
462 thread.update(cx, |thread, _cx| {
463 let last_message = thread.messages().last().unwrap();
464 let text = last_message
465 .content
466 .iter()
467 .filter_map(|content| {
468 if let MessageContent::Text(text) = content {
469 Some(text.as_str())
470 } else {
471 None
472 }
473 })
474 .collect::<String>();
475
476 assert!(text.contains("Ding"));
477 });
478}
479
480#[gpui::test]
481async fn test_profiles(cx: &mut TestAppContext) {
482 let ThreadTest {
483 model, thread, fs, ..
484 } = setup(cx, TestModel::Fake).await;
485 let fake_model = model.as_fake();
486
487 thread.update(cx, |thread, _cx| {
488 thread.add_tool(DelayTool);
489 thread.add_tool(EchoTool);
490 thread.add_tool(InfiniteTool);
491 });
492
493 // Override profiles and wait for settings to be loaded.
494 fs.insert_file(
495 paths::settings_file(),
496 json!({
497 "agent": {
498 "profiles": {
499 "test-1": {
500 "name": "Test Profile 1",
501 "tools": {
502 EchoTool.name(): true,
503 DelayTool.name(): true,
504 }
505 },
506 "test-2": {
507 "name": "Test Profile 2",
508 "tools": {
509 InfiniteTool.name(): true,
510 }
511 }
512 }
513 }
514 })
515 .to_string()
516 .into_bytes(),
517 )
518 .await;
519 cx.run_until_parked();
520
521 // Test that test-1 profile (default) has echo and delay tools
522 thread.update(cx, |thread, cx| {
523 thread.set_profile(AgentProfileId("test-1".into()));
524 thread.send("test", cx);
525 });
526 cx.run_until_parked();
527
528 let mut pending_completions = fake_model.pending_completions();
529 assert_eq!(pending_completions.len(), 1);
530 let completion = pending_completions.pop().unwrap();
531 let tool_names: Vec<String> = completion
532 .tools
533 .iter()
534 .map(|tool| tool.name.clone())
535 .collect();
536 assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
537 fake_model.end_last_completion_stream();
538
539 // Switch to test-2 profile, and verify that it has only the infinite tool.
540 thread.update(cx, |thread, cx| {
541 thread.set_profile(AgentProfileId("test-2".into()));
542 thread.send("test2", cx)
543 });
544 cx.run_until_parked();
545 let mut pending_completions = fake_model.pending_completions();
546 assert_eq!(pending_completions.len(), 1);
547 let completion = pending_completions.pop().unwrap();
548 let tool_names: Vec<String> = completion
549 .tools
550 .iter()
551 .map(|tool| tool.name.clone())
552 .collect();
553 assert_eq!(tool_names, vec![InfiniteTool.name()]);
554}
555
/// Cancelling a send while a tool is still running closes the event stream. The thread must
/// still accept and answer a new message afterwards. Runs against a real model and is
/// ignored on CI.
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(InfiniteTool);
        thread.add_tool(EchoTool);
        thread.send(
            "Call the echo tool and then call the infinite tool, then explain their output",
            cx,
        )
    });

    // Wait until both tools are called.
    // Tool calls must be announced in exactly this order (by title), and the echo call must
    // also report `Completed` before we stop reading events.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            AgentResponseEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Only a `Completed` status update for the echo call counts.
            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, _cx| thread.cancel());
    events.collect::<Vec<_>>().await;

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send("Testing: reply with 'Hello' then stop.", cx)
        })
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.messages().last().unwrap().content,
            vec![MessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
622
623#[gpui::test]
624async fn test_refusal(cx: &mut TestAppContext) {
625 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
626 let fake_model = model.as_fake();
627
628 let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
629 cx.run_until_parked();
630 thread.read_with(cx, |thread, _| {
631 assert_eq!(
632 thread.to_markdown(),
633 indoc! {"
634 ## user
635 Hello
636 "}
637 );
638 });
639
640 fake_model.send_last_completion_stream_text_chunk("Hey!");
641 cx.run_until_parked();
642 thread.read_with(cx, |thread, _| {
643 assert_eq!(
644 thread.to_markdown(),
645 indoc! {"
646 ## user
647 Hello
648 ## assistant
649 Hey!
650 "}
651 );
652 });
653
654 // If the model refuses to continue, the thread should remove all the messages after the last user message.
655 fake_model
656 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
657 let events = events.collect::<Vec<_>>().await;
658 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
659 thread.read_with(cx, |thread, _| {
660 assert_eq!(thread.to_markdown(), "");
661 });
662}
663
664#[gpui::test]
665async fn test_agent_connection(cx: &mut TestAppContext) {
666 cx.update(settings::init);
667 let templates = Templates::new();
668
669 // Initialize language model system with test provider
670 cx.update(|cx| {
671 gpui_tokio::init(cx);
672 client::init_settings(cx);
673
674 let http_client = FakeHttpClient::with_404_response();
675 let clock = Arc::new(clock::FakeSystemClock::new());
676 let client = Client::new(clock, http_client, cx);
677 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
678 language_model::init(client.clone(), cx);
679 language_models::init(user_store.clone(), client.clone(), cx);
680 Project::init_settings(cx);
681 LanguageModelRegistry::test(cx);
682 agent_settings::init(cx);
683 });
684 cx.executor().forbid_parking();
685
686 // Create a project for new_thread
687 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
688 fake_fs.insert_tree(path!("/test"), json!({})).await;
689 let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
690 let cwd = Path::new("/test");
691
692 // Create agent and connection
693 let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
694 .await
695 .unwrap();
696 let connection = NativeAgentConnection(agent.clone());
697
698 // Test model_selector returns Some
699 let selector_opt = connection.model_selector();
700 assert!(
701 selector_opt.is_some(),
702 "agent2 should always support ModelSelector"
703 );
704 let selector = selector_opt.unwrap();
705
706 // Test list_models
707 let listed_models = cx
708 .update(|cx| {
709 let mut async_cx = cx.to_async();
710 selector.list_models(&mut async_cx)
711 })
712 .await
713 .expect("list_models should succeed");
714 assert!(!listed_models.is_empty(), "should have at least one model");
715 assert_eq!(listed_models[0].id().0, "fake");
716
717 // Create a thread using new_thread
718 let connection_rc = Rc::new(connection.clone());
719 let acp_thread = cx
720 .update(|cx| {
721 let mut async_cx = cx.to_async();
722 connection_rc.new_thread(project, cwd, &mut async_cx)
723 })
724 .await
725 .expect("new_thread should succeed");
726
727 // Get the session_id from the AcpThread
728 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
729
730 // Test selected_model returns the default
731 let model = cx
732 .update(|cx| {
733 let mut async_cx = cx.to_async();
734 selector.selected_model(&session_id, &mut async_cx)
735 })
736 .await
737 .expect("selected_model should succeed");
738 let model = model.as_fake();
739 assert_eq!(model.id().0, "fake", "should return default model");
740
741 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
742 cx.run_until_parked();
743 model.send_last_completion_stream_text_chunk("def");
744 cx.run_until_parked();
745 acp_thread.read_with(cx, |thread, cx| {
746 assert_eq!(
747 thread.to_markdown(cx),
748 indoc! {"
749 ## User
750
751 abc
752
753 ## Assistant
754
755 def
756
757 "}
758 )
759 });
760
761 // Test cancel
762 cx.update(|cx| connection.cancel(&session_id, cx));
763 request.await.expect("prompt should fail gracefully");
764
765 // Ensure that dropping the ACP thread causes the native thread to be
766 // dropped as well.
767 cx.update(|_| drop(acp_thread));
768 let result = cx
769 .update(|cx| {
770 connection.prompt(
771 acp::PromptRequest {
772 session_id: session_id.clone(),
773 prompt: vec!["ghi".into()],
774 },
775 cx,
776 )
777 })
778 .await;
779 assert_eq!(
780 result.as_ref().unwrap_err().to_string(),
781 "Session not found",
782 "unexpected result: {:?}",
783 result
784 );
785}
786
787#[gpui::test]
788async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
789 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
790 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
791 let fake_model = model.as_fake();
792
793 let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
794 cx.run_until_parked();
795
796 // Simulate streaming partial input.
797 let input = json!({});
798 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
799 LanguageModelToolUse {
800 id: "1".into(),
801 name: ThinkingTool.name().into(),
802 raw_input: input.to_string(),
803 input,
804 is_input_complete: false,
805 },
806 ));
807
808 // Input streaming completed
809 let input = json!({ "content": "Thinking hard!" });
810 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
811 LanguageModelToolUse {
812 id: "1".into(),
813 name: "thinking".into(),
814 raw_input: input.to_string(),
815 input,
816 is_input_complete: true,
817 },
818 ));
819 fake_model.end_last_completion_stream();
820 cx.run_until_parked();
821
822 let tool_call = expect_tool_call(&mut events).await;
823 assert_eq!(
824 tool_call,
825 acp::ToolCall {
826 id: acp::ToolCallId("1".into()),
827 title: "Thinking".into(),
828 kind: acp::ToolKind::Think,
829 status: acp::ToolCallStatus::Pending,
830 content: vec![],
831 locations: vec![],
832 raw_input: Some(json!({})),
833 raw_output: None,
834 }
835 );
836 let update = expect_tool_call_update_fields(&mut events).await;
837 assert_eq!(
838 update,
839 acp::ToolCallUpdate {
840 id: acp::ToolCallId("1".into()),
841 fields: acp::ToolCallUpdateFields {
842 title: Some("Thinking".into()),
843 kind: Some(acp::ToolKind::Think),
844 raw_input: Some(json!({ "content": "Thinking hard!" })),
845 ..Default::default()
846 },
847 }
848 );
849 let update = expect_tool_call_update_fields(&mut events).await;
850 assert_eq!(
851 update,
852 acp::ToolCallUpdate {
853 id: acp::ToolCallId("1".into()),
854 fields: acp::ToolCallUpdateFields {
855 status: Some(acp::ToolCallStatus::InProgress),
856 ..Default::default()
857 },
858 }
859 );
860 let update = expect_tool_call_update_fields(&mut events).await;
861 assert_eq!(
862 update,
863 acp::ToolCallUpdate {
864 id: acp::ToolCallId("1".into()),
865 fields: acp::ToolCallUpdateFields {
866 content: Some(vec!["Thinking hard!".into()]),
867 ..Default::default()
868 },
869 }
870 );
871 let update = expect_tool_call_update_fields(&mut events).await;
872 assert_eq!(
873 update,
874 acp::ToolCallUpdate {
875 id: acp::ToolCallId("1".into()),
876 fields: acp::ToolCallUpdateFields {
877 status: Some(acp::ToolCallStatus::Completed),
878 raw_output: Some("Finished thinking.".into()),
879 ..Default::default()
880 },
881 }
882 );
883}
884
885/// Filters out the stop events for asserting against in tests
886fn stop_events(
887 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
888) -> Vec<acp::StopReason> {
889 result_events
890 .into_iter()
891 .filter_map(|event| match event.unwrap() {
892 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
893 _ => None,
894 })
895 .collect()
896}
897
/// Handles returned by [`setup`] for driving a [`Thread`] in tests.
struct ThreadTest {
    /// The model the thread was created with (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread; tests may mutate it (e.g. `shell`) before sending.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Fake filesystem backing the test project and the watched settings file.
    fs: Arc<FakeFs>,
}
904
905enum TestModel {
906 Sonnet4,
907 Sonnet4Thinking,
908 Fake,
909}
910
911impl TestModel {
912 fn id(&self) -> LanguageModelId {
913 match self {
914 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
915 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
916 TestModel::Fake => unreachable!(),
917 }
918 }
919}
920
/// Builds a [`Thread`] for tests on top of a fake filesystem and project.
///
/// Steps:
/// - Writes a settings file whose default profile ("test-profile") enables all the test tools.
/// - Initializes the settings, HTTP client and language-model globals.
/// - Starts watching the settings file, so later writes (as in `test_profiles`) are applied.
/// - Picks the model: `TestModel::Fake` gets a fresh `FakeLanguageModel`; any other variant is
///   looked up in the global `LanguageModelRegistry`, and its provider is authenticated first.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the settings directory cannot be created;
/// - the HTTP client cannot be built;
/// - the requested model or its provider is unavailable;
/// - authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed so tests using real models can block on network I/O.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client, used when talking to real model providers.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        // Apply the settings file written above (and any later rewrites of it).
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: fake models are ready immediately; real ones are found in the registry
    // and only returned after their provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            model.clone(),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1019
/// Installs `env_logger` before any test runs, but only when `RUST_LOG` is set (and valid
/// Unicode), so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}
1027
1028fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1029 let fs = fs.clone();
1030 cx.spawn({
1031 async move |cx| {
1032 let mut new_settings_content_rx = settings::watch_config_file(
1033 cx.background_executor(),
1034 fs,
1035 paths::settings_file().clone(),
1036 );
1037
1038 while let Some(new_settings_content) = new_settings_content_rx.next().await {
1039 cx.update(|cx| {
1040 SettingsStore::update_global(cx, |settings, cx| {
1041 settings.set_user_settings(&new_settings_content, cx)
1042 })
1043 })
1044 .ok();
1045 }
1046 }
1047 })
1048 .detach();
1049}