1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use agent_settings::AgentProfileId;
6use anyhow::Result;
7use client::{Client, UserStore};
8use fs::{FakeFs, Fs};
9use futures::channel::mpsc::UnboundedReceiver;
10use gpui::{
11 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
12};
13use indoc::indoc;
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
16 LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
17 fake_provider::FakeLanguageModel,
18};
19use project::Project;
20use prompt_store::ProjectContext;
21use reqwest_client::ReqwestClient;
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use settings::SettingsStore;
26use smol::stream::StreamExt;
27use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
28use util::path;
29
30mod test_tools;
31use test_tools::*;
32
33#[gpui::test]
34#[ignore = "can't run on CI yet"]
35async fn test_echo(cx: &mut TestAppContext) {
36 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
37
38 let events = thread
39 .update(cx, |thread, cx| {
40 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
41 })
42 .collect()
43 .await;
44 thread.update(cx, |thread, _cx| {
45 assert_eq!(
46 thread.last_message().unwrap().to_markdown(),
47 indoc! {"
48 ## Assistant
49
50 Hello
51 "}
52 )
53 });
54 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
55}
56
57#[gpui::test]
58#[ignore = "can't run on CI yet"]
59async fn test_thinking(cx: &mut TestAppContext) {
60 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
61
62 let events = thread
63 .update(cx, |thread, cx| {
64 thread.send(
65 UserMessageId::new(),
66 [indoc! {"
67 Testing:
68
69 Generate a thinking step where you just think the word 'Think',
70 and have your final answer be 'Hello'
71 "}],
72 cx,
73 )
74 })
75 .collect()
76 .await;
77 thread.update(cx, |thread, _cx| {
78 assert_eq!(
79 thread.last_message().unwrap().to_markdown(),
80 indoc! {"
81 ## Assistant
82
83 <think>Think</think>
84 Hello
85 "}
86 )
87 });
88 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
89}
90
91#[gpui::test]
92async fn test_system_prompt(cx: &mut TestAppContext) {
93 let ThreadTest {
94 model,
95 thread,
96 project_context,
97 ..
98 } = setup(cx, TestModel::Fake).await;
99 let fake_model = model.as_fake();
100
101 project_context.borrow_mut().shell = "test-shell".into();
102 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
103 thread.update(cx, |thread, cx| {
104 thread.send(UserMessageId::new(), ["abc"], cx)
105 });
106 cx.run_until_parked();
107 let mut pending_completions = fake_model.pending_completions();
108 assert_eq!(
109 pending_completions.len(),
110 1,
111 "unexpected pending completions: {:?}",
112 pending_completions
113 );
114
115 let pending_completion = pending_completions.pop().unwrap();
116 assert_eq!(pending_completion.messages[0].role, Role::System);
117
118 let system_message = &pending_completion.messages[0];
119 let system_prompt = system_message.content[0].to_str().unwrap();
120 assert!(
121 system_prompt.contains("test-shell"),
122 "unexpected system message: {:?}",
123 system_message
124 );
125 assert!(
126 system_prompt.contains("## Fixing Diagnostics"),
127 "unexpected system message: {:?}",
128 system_message
129 );
130}
131
132#[gpui::test]
133#[ignore = "can't run on CI yet"]
134async fn test_basic_tool_calls(cx: &mut TestAppContext) {
135 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
136
137 // Test a tool call that's likely to complete *before* streaming stops.
138 let events = thread
139 .update(cx, |thread, cx| {
140 thread.add_tool(EchoTool);
141 thread.send(
142 UserMessageId::new(),
143 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
144 cx,
145 )
146 })
147 .collect()
148 .await;
149 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
150
151 // Test a tool calls that's likely to complete *after* streaming stops.
152 let events = thread
153 .update(cx, |thread, cx| {
154 thread.remove_tool(&AgentTool::name(&EchoTool));
155 thread.add_tool(DelayTool);
156 thread.send(
157 UserMessageId::new(),
158 [
159 "Now call the delay tool with 200ms.",
160 "When the timer goes off, then you echo the output of the tool.",
161 ],
162 cx,
163 )
164 })
165 .collect()
166 .await;
167 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
168 thread.update(cx, |thread, _cx| {
169 assert!(
170 thread
171 .last_message()
172 .unwrap()
173 .as_agent_message()
174 .unwrap()
175 .content
176 .iter()
177 .any(|content| {
178 if let AgentMessageContent::Text(text) = content {
179 text.contains("Ding")
180 } else {
181 false
182 }
183 }),
184 "{}",
185 thread.to_markdown()
186 );
187 });
188}
189
190#[gpui::test]
191#[ignore = "can't run on CI yet"]
192async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
193 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
194
195 // Test a tool call that's likely to complete *before* streaming stops.
196 let mut events = thread.update(cx, |thread, cx| {
197 thread.add_tool(WordListTool);
198 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
199 });
200
201 let mut saw_partial_tool_use = false;
202 while let Some(event) = events.next().await {
203 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
204 thread.update(cx, |thread, _cx| {
205 // Look for a tool use in the thread's last message
206 let message = thread.last_message().unwrap();
207 let agent_message = message.as_agent_message().unwrap();
208 let last_content = agent_message.content.last().unwrap();
209 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
210 assert_eq!(last_tool_use.name.as_ref(), "word_list");
211 if tool_call.status == acp::ToolCallStatus::Pending {
212 if !last_tool_use.is_input_complete
213 && last_tool_use.input.get("g").is_none()
214 {
215 saw_partial_tool_use = true;
216 }
217 } else {
218 last_tool_use
219 .input
220 .get("a")
221 .expect("'a' has streamed because input is now complete");
222 last_tool_use
223 .input
224 .get("g")
225 .expect("'g' has streamed because input is now complete");
226 }
227 } else {
228 panic!("last content should be a tool use");
229 }
230 });
231 }
232 }
233
234 assert!(
235 saw_partial_tool_use,
236 "should see at least one partially streamed tool use in the history"
237 );
238}
239
240#[gpui::test]
241async fn test_tool_authorization(cx: &mut TestAppContext) {
242 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
243 let fake_model = model.as_fake();
244
245 let mut events = thread.update(cx, |thread, cx| {
246 thread.add_tool(ToolRequiringPermission);
247 thread.send(UserMessageId::new(), ["abc"], cx)
248 });
249 cx.run_until_parked();
250 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
251 LanguageModelToolUse {
252 id: "tool_id_1".into(),
253 name: ToolRequiringPermission.name().into(),
254 raw_input: "{}".into(),
255 input: json!({}),
256 is_input_complete: true,
257 },
258 ));
259 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
260 LanguageModelToolUse {
261 id: "tool_id_2".into(),
262 name: ToolRequiringPermission.name().into(),
263 raw_input: "{}".into(),
264 input: json!({}),
265 is_input_complete: true,
266 },
267 ));
268 fake_model.end_last_completion_stream();
269 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
270 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
271
272 // Approve the first
273 tool_call_auth_1
274 .response
275 .send(tool_call_auth_1.options[1].id.clone())
276 .unwrap();
277 cx.run_until_parked();
278
279 // Reject the second
280 tool_call_auth_2
281 .response
282 .send(tool_call_auth_1.options[2].id.clone())
283 .unwrap();
284 cx.run_until_parked();
285
286 let completion = fake_model.pending_completions().pop().unwrap();
287 let message = completion.messages.last().unwrap();
288 assert_eq!(
289 message.content,
290 vec![
291 language_model::MessageContent::ToolResult(LanguageModelToolResult {
292 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
293 tool_name: ToolRequiringPermission.name().into(),
294 is_error: false,
295 content: "Allowed".into(),
296 output: Some("Allowed".into())
297 }),
298 language_model::MessageContent::ToolResult(LanguageModelToolResult {
299 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
300 tool_name: ToolRequiringPermission.name().into(),
301 is_error: true,
302 content: "Permission to run tool denied by user".into(),
303 output: None
304 })
305 ]
306 );
307
308 // Simulate yet another tool call.
309 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
310 LanguageModelToolUse {
311 id: "tool_id_3".into(),
312 name: ToolRequiringPermission.name().into(),
313 raw_input: "{}".into(),
314 input: json!({}),
315 is_input_complete: true,
316 },
317 ));
318 fake_model.end_last_completion_stream();
319
320 // Respond by always allowing tools.
321 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
322 tool_call_auth_3
323 .response
324 .send(tool_call_auth_3.options[0].id.clone())
325 .unwrap();
326 cx.run_until_parked();
327 let completion = fake_model.pending_completions().pop().unwrap();
328 let message = completion.messages.last().unwrap();
329 assert_eq!(
330 message.content,
331 vec![language_model::MessageContent::ToolResult(
332 LanguageModelToolResult {
333 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
334 tool_name: ToolRequiringPermission.name().into(),
335 is_error: false,
336 content: "Allowed".into(),
337 output: Some("Allowed".into())
338 }
339 )]
340 );
341
342 // Simulate a final tool call, ensuring we don't trigger authorization.
343 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
344 LanguageModelToolUse {
345 id: "tool_id_4".into(),
346 name: ToolRequiringPermission.name().into(),
347 raw_input: "{}".into(),
348 input: json!({}),
349 is_input_complete: true,
350 },
351 ));
352 fake_model.end_last_completion_stream();
353 cx.run_until_parked();
354 let completion = fake_model.pending_completions().pop().unwrap();
355 let message = completion.messages.last().unwrap();
356 assert_eq!(
357 message.content,
358 vec![language_model::MessageContent::ToolResult(
359 LanguageModelToolResult {
360 tool_use_id: "tool_id_4".into(),
361 tool_name: ToolRequiringPermission.name().into(),
362 is_error: false,
363 content: "Allowed".into(),
364 output: Some("Allowed".into())
365 }
366 )]
367 );
368}
369
370#[gpui::test]
371async fn test_tool_hallucination(cx: &mut TestAppContext) {
372 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
373 let fake_model = model.as_fake();
374
375 let mut events = thread.update(cx, |thread, cx| {
376 thread.send(UserMessageId::new(), ["abc"], cx)
377 });
378 cx.run_until_parked();
379 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
380 LanguageModelToolUse {
381 id: "tool_id_1".into(),
382 name: "nonexistent_tool".into(),
383 raw_input: "{}".into(),
384 input: json!({}),
385 is_input_complete: true,
386 },
387 ));
388 fake_model.end_last_completion_stream();
389
390 let tool_call = expect_tool_call(&mut events).await;
391 assert_eq!(tool_call.title, "nonexistent_tool");
392 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
393 let update = expect_tool_call_update_fields(&mut events).await;
394 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
395}
396
397async fn expect_tool_call(
398 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
399) -> acp::ToolCall {
400 let event = events
401 .next()
402 .await
403 .expect("no tool call authorization event received")
404 .unwrap();
405 match event {
406 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
407 event => {
408 panic!("Unexpected event {event:?}");
409 }
410 }
411}
412
413async fn expect_tool_call_update_fields(
414 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
415) -> acp::ToolCallUpdate {
416 let event = events
417 .next()
418 .await
419 .expect("no tool call authorization event received")
420 .unwrap();
421 match event {
422 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
423 return update;
424 }
425 event => {
426 panic!("Unexpected event {event:?}");
427 }
428 }
429}
430
431async fn next_tool_call_authorization(
432 events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
433) -> ToolCallAuthorization {
434 loop {
435 let event = events
436 .next()
437 .await
438 .expect("no tool call authorization event received")
439 .unwrap();
440 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
441 let permission_kinds = tool_call_authorization
442 .options
443 .iter()
444 .map(|o| o.kind)
445 .collect::<Vec<_>>();
446 assert_eq!(
447 permission_kinds,
448 vec![
449 acp::PermissionOptionKind::AllowAlways,
450 acp::PermissionOptionKind::AllowOnce,
451 acp::PermissionOptionKind::RejectOnce,
452 ]
453 );
454 return tool_call_authorization;
455 }
456 }
457}
458
459#[gpui::test]
460#[ignore = "can't run on CI yet"]
461async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
462 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
463
464 // Test concurrent tool calls with different delay times
465 let events = thread
466 .update(cx, |thread, cx| {
467 thread.add_tool(DelayTool);
468 thread.send(
469 UserMessageId::new(),
470 [
471 "Call the delay tool twice in the same message.",
472 "Once with 100ms. Once with 300ms.",
473 "When both timers are complete, describe the outputs.",
474 ],
475 cx,
476 )
477 })
478 .collect()
479 .await;
480
481 let stop_reasons = stop_events(events);
482 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
483
484 thread.update(cx, |thread, _cx| {
485 let last_message = thread.last_message().unwrap();
486 let agent_message = last_message.as_agent_message().unwrap();
487 let text = agent_message
488 .content
489 .iter()
490 .filter_map(|content| {
491 if let AgentMessageContent::Text(text) = content {
492 Some(text.as_str())
493 } else {
494 None
495 }
496 })
497 .collect::<String>();
498
499 assert!(text.contains("Ding"));
500 });
501}
502
503#[gpui::test]
504async fn test_profiles(cx: &mut TestAppContext) {
505 let ThreadTest {
506 model, thread, fs, ..
507 } = setup(cx, TestModel::Fake).await;
508 let fake_model = model.as_fake();
509
510 thread.update(cx, |thread, _cx| {
511 thread.add_tool(DelayTool);
512 thread.add_tool(EchoTool);
513 thread.add_tool(InfiniteTool);
514 });
515
516 // Override profiles and wait for settings to be loaded.
517 fs.insert_file(
518 paths::settings_file(),
519 json!({
520 "agent": {
521 "profiles": {
522 "test-1": {
523 "name": "Test Profile 1",
524 "tools": {
525 EchoTool.name(): true,
526 DelayTool.name(): true,
527 }
528 },
529 "test-2": {
530 "name": "Test Profile 2",
531 "tools": {
532 InfiniteTool.name(): true,
533 }
534 }
535 }
536 }
537 })
538 .to_string()
539 .into_bytes(),
540 )
541 .await;
542 cx.run_until_parked();
543
544 // Test that test-1 profile (default) has echo and delay tools
545 thread.update(cx, |thread, cx| {
546 thread.set_profile(AgentProfileId("test-1".into()));
547 thread.send(UserMessageId::new(), ["test"], cx);
548 });
549 cx.run_until_parked();
550
551 let mut pending_completions = fake_model.pending_completions();
552 assert_eq!(pending_completions.len(), 1);
553 let completion = pending_completions.pop().unwrap();
554 let tool_names: Vec<String> = completion
555 .tools
556 .iter()
557 .map(|tool| tool.name.clone())
558 .collect();
559 assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
560 fake_model.end_last_completion_stream();
561
562 // Switch to test-2 profile, and verify that it has only the infinite tool.
563 thread.update(cx, |thread, cx| {
564 thread.set_profile(AgentProfileId("test-2".into()));
565 thread.send(UserMessageId::new(), ["test2"], cx)
566 });
567 cx.run_until_parked();
568 let mut pending_completions = fake_model.pending_completions();
569 assert_eq!(pending_completions.len(), 1);
570 let completion = pending_completions.pop().unwrap();
571 let tool_names: Vec<String> = completion
572 .tools
573 .iter()
574 .map(|tool| tool.name.clone())
575 .collect();
576 assert_eq!(tool_names, vec![InfiniteTool.name()]);
577}
578
579#[gpui::test]
580#[ignore = "can't run on CI yet"]
581async fn test_cancellation(cx: &mut TestAppContext) {
582 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
583
584 let mut events = thread.update(cx, |thread, cx| {
585 thread.add_tool(InfiniteTool);
586 thread.add_tool(EchoTool);
587 thread.send(
588 UserMessageId::new(),
589 ["Call the echo tool, then call the infinite tool, then explain their output"],
590 cx,
591 )
592 });
593
594 // Wait until both tools are called.
595 let mut expected_tools = vec!["Echo", "Infinite Tool"];
596 let mut echo_id = None;
597 let mut echo_completed = false;
598 while let Some(event) = events.next().await {
599 match event.unwrap() {
600 AgentResponseEvent::ToolCall(tool_call) => {
601 assert_eq!(tool_call.title, expected_tools.remove(0));
602 if tool_call.title == "Echo" {
603 echo_id = Some(tool_call.id);
604 }
605 }
606 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
607 acp::ToolCallUpdate {
608 id,
609 fields:
610 acp::ToolCallUpdateFields {
611 status: Some(acp::ToolCallStatus::Completed),
612 ..
613 },
614 },
615 )) if Some(&id) == echo_id.as_ref() => {
616 echo_completed = true;
617 }
618 _ => {}
619 }
620
621 if expected_tools.is_empty() && echo_completed {
622 break;
623 }
624 }
625
626 // Cancel the current send and ensure that the event stream is closed, even
627 // if one of the tools is still running.
628 thread.update(cx, |thread, _cx| thread.cancel());
629 events.collect::<Vec<_>>().await;
630
631 // Ensure we can still send a new message after cancellation.
632 let events = thread
633 .update(cx, |thread, cx| {
634 thread.send(
635 UserMessageId::new(),
636 ["Testing: reply with 'Hello' then stop."],
637 cx,
638 )
639 })
640 .collect::<Vec<_>>()
641 .await;
642 thread.update(cx, |thread, _cx| {
643 let message = thread.last_message().unwrap();
644 let agent_message = message.as_agent_message().unwrap();
645 assert_eq!(
646 agent_message.content,
647 vec![AgentMessageContent::Text("Hello".to_string())]
648 );
649 });
650 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
651}
652
653#[gpui::test]
654async fn test_refusal(cx: &mut TestAppContext) {
655 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
656 let fake_model = model.as_fake();
657
658 let events = thread.update(cx, |thread, cx| {
659 thread.send(UserMessageId::new(), ["Hello"], cx)
660 });
661 cx.run_until_parked();
662 thread.read_with(cx, |thread, _| {
663 assert_eq!(
664 thread.to_markdown(),
665 indoc! {"
666 ## User
667
668 Hello
669 "}
670 );
671 });
672
673 fake_model.send_last_completion_stream_text_chunk("Hey!");
674 cx.run_until_parked();
675 thread.read_with(cx, |thread, _| {
676 assert_eq!(
677 thread.to_markdown(),
678 indoc! {"
679 ## User
680
681 Hello
682
683 ## Assistant
684
685 Hey!
686 "}
687 );
688 });
689
690 // If the model refuses to continue, the thread should remove all the messages after the last user message.
691 fake_model
692 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
693 let events = events.collect::<Vec<_>>().await;
694 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
695 thread.read_with(cx, |thread, _| {
696 assert_eq!(thread.to_markdown(), "");
697 });
698}
699
700#[gpui::test]
701async fn test_truncate(cx: &mut TestAppContext) {
702 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
703 let fake_model = model.as_fake();
704
705 let message_id = UserMessageId::new();
706 thread.update(cx, |thread, cx| {
707 thread.send(message_id.clone(), ["Hello"], cx)
708 });
709 cx.run_until_parked();
710 thread.read_with(cx, |thread, _| {
711 assert_eq!(
712 thread.to_markdown(),
713 indoc! {"
714 ## User
715
716 Hello
717 "}
718 );
719 });
720
721 fake_model.send_last_completion_stream_text_chunk("Hey!");
722 cx.run_until_parked();
723 thread.read_with(cx, |thread, _| {
724 assert_eq!(
725 thread.to_markdown(),
726 indoc! {"
727 ## User
728
729 Hello
730
731 ## Assistant
732
733 Hey!
734 "}
735 );
736 });
737
738 thread
739 .update(cx, |thread, _cx| thread.truncate(message_id))
740 .unwrap();
741 cx.run_until_parked();
742 thread.read_with(cx, |thread, _| {
743 assert_eq!(thread.to_markdown(), "");
744 });
745
746 // Ensure we can still send a new message after truncation.
747 thread.update(cx, |thread, cx| {
748 thread.send(UserMessageId::new(), ["Hi"], cx)
749 });
750 thread.update(cx, |thread, _cx| {
751 assert_eq!(
752 thread.to_markdown(),
753 indoc! {"
754 ## User
755
756 Hi
757 "}
758 );
759 });
760 cx.run_until_parked();
761 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
762 cx.run_until_parked();
763 thread.read_with(cx, |thread, _| {
764 assert_eq!(
765 thread.to_markdown(),
766 indoc! {"
767 ## User
768
769 Hi
770
771 ## Assistant
772
773 Ahoy!
774 "}
775 );
776 });
777}
778
// End-to-end test of `NativeAgentConnection`: model selector availability and
// model listing, thread creation, default model selection, sending a prompt,
// cancelling it, and the "Session not found" error once the ACP thread has
// been dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // Fake HTTP client and fake clock are used here, unlike `setup`, which
        // builds a production client.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    // NOTE(review): the tree is inserted under `path!("/test")` but the project
    // and cwd use `Path::new("/test")`; `setup` uses `path!` for both. Confirm
    // these agree on platforms where `path!` rewrites the literal.
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    // The list is expected to be grouped; the first model of the "Fake" group
    // must carry the id "fake/fake".
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected model id back to the model object so the fake model
    // can be driven directly below.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt through the ACP thread and stream one text chunk back.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    // After cancelling, the in-flight request must still resolve to `Ok`.
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting the old session id must now fail with this exact message.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
908
909#[gpui::test]
910async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
911 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
912 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
913 let fake_model = model.as_fake();
914
915 let mut events = thread.update(cx, |thread, cx| {
916 thread.send(UserMessageId::new(), ["Think"], cx)
917 });
918 cx.run_until_parked();
919
920 // Simulate streaming partial input.
921 let input = json!({});
922 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
923 LanguageModelToolUse {
924 id: "1".into(),
925 name: ThinkingTool.name().into(),
926 raw_input: input.to_string(),
927 input,
928 is_input_complete: false,
929 },
930 ));
931
932 // Input streaming completed
933 let input = json!({ "content": "Thinking hard!" });
934 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
935 LanguageModelToolUse {
936 id: "1".into(),
937 name: "thinking".into(),
938 raw_input: input.to_string(),
939 input,
940 is_input_complete: true,
941 },
942 ));
943 fake_model.end_last_completion_stream();
944 cx.run_until_parked();
945
946 let tool_call = expect_tool_call(&mut events).await;
947 assert_eq!(
948 tool_call,
949 acp::ToolCall {
950 id: acp::ToolCallId("1".into()),
951 title: "Thinking".into(),
952 kind: acp::ToolKind::Think,
953 status: acp::ToolCallStatus::Pending,
954 content: vec![],
955 locations: vec![],
956 raw_input: Some(json!({})),
957 raw_output: None,
958 }
959 );
960 let update = expect_tool_call_update_fields(&mut events).await;
961 assert_eq!(
962 update,
963 acp::ToolCallUpdate {
964 id: acp::ToolCallId("1".into()),
965 fields: acp::ToolCallUpdateFields {
966 title: Some("Thinking".into()),
967 kind: Some(acp::ToolKind::Think),
968 raw_input: Some(json!({ "content": "Thinking hard!" })),
969 ..Default::default()
970 },
971 }
972 );
973 let update = expect_tool_call_update_fields(&mut events).await;
974 assert_eq!(
975 update,
976 acp::ToolCallUpdate {
977 id: acp::ToolCallId("1".into()),
978 fields: acp::ToolCallUpdateFields {
979 status: Some(acp::ToolCallStatus::InProgress),
980 ..Default::default()
981 },
982 }
983 );
984 let update = expect_tool_call_update_fields(&mut events).await;
985 assert_eq!(
986 update,
987 acp::ToolCallUpdate {
988 id: acp::ToolCallId("1".into()),
989 fields: acp::ToolCallUpdateFields {
990 content: Some(vec!["Thinking hard!".into()]),
991 ..Default::default()
992 },
993 }
994 );
995 let update = expect_tool_call_update_fields(&mut events).await;
996 assert_eq!(
997 update,
998 acp::ToolCallUpdate {
999 id: acp::ToolCallId("1".into()),
1000 fields: acp::ToolCallUpdateFields {
1001 status: Some(acp::ToolCallStatus::Completed),
1002 raw_output: Some("Finished thinking.".into()),
1003 ..Default::default()
1004 },
1005 }
1006 );
1007}
1008
1009/// Filters out the stop events for asserting against in tests
1010fn stop_events(
1011 result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
1012) -> Vec<acp::StopReason> {
1013 result_events
1014 .into_iter()
1015 .filter_map(|event| match event.unwrap() {
1016 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
1017 _ => None,
1018 })
1019 .collect()
1020}
1021
/// Everything `setup` hands back to a test.
struct ThreadTest {
    /// The language model the thread was created with; tests using
    /// `TestModel::Fake` call `as_fake()` on it to drive completions.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread (the thread receives a clone of
    /// this `Rc`), so tests can mutate it before sending a message.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Fake file system holding the settings file and the `/test` project tree.
    fs: Arc<FakeFs>,
}
1028
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// Model looked up in the registry by the id "claude-sonnet-4-latest".
    Sonnet4,
    /// Model looked up in the registry by the id "claude-sonnet-4-thinking-latest".
    Sonnet4Thinking,
    /// A `FakeLanguageModel::default()` instance; it has no registry id.
    Fake,
}
1034
1035impl TestModel {
1036 fn id(&self) -> LanguageModelId {
1037 match self {
1038 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1039 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1040 TestModel::Fake => unreachable!(),
1041 }
1042 }
1043}
1044
1045async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1046 cx.executor().allow_parking();
1047
1048 let fs = FakeFs::new(cx.background_executor.clone());
1049 fs.create_dir(paths::settings_file().parent().unwrap())
1050 .await
1051 .unwrap();
1052 fs.insert_file(
1053 paths::settings_file(),
1054 json!({
1055 "agent": {
1056 "default_profile": "test-profile",
1057 "profiles": {
1058 "test-profile": {
1059 "name": "Test Profile",
1060 "tools": {
1061 EchoTool.name(): true,
1062 DelayTool.name(): true,
1063 WordListTool.name(): true,
1064 ToolRequiringPermission.name(): true,
1065 InfiniteTool.name(): true,
1066 }
1067 }
1068 }
1069 }
1070 })
1071 .to_string()
1072 .into_bytes(),
1073 )
1074 .await;
1075
1076 cx.update(|cx| {
1077 settings::init(cx);
1078 Project::init_settings(cx);
1079 agent_settings::init(cx);
1080 gpui_tokio::init(cx);
1081 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1082 cx.set_http_client(Arc::new(http_client));
1083
1084 client::init_settings(cx);
1085 let client = Client::production(cx);
1086 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1087 language_model::init(client.clone(), cx);
1088 language_models::init(user_store.clone(), client.clone(), cx);
1089
1090 watch_settings(fs.clone(), cx);
1091 });
1092
1093 let templates = Templates::new();
1094
1095 fs.insert_tree(path!("/test"), json!({})).await;
1096 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1097
1098 let model = cx
1099 .update(|cx| {
1100 if let TestModel::Fake = model {
1101 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1102 } else {
1103 let model_id = model.id();
1104 let models = LanguageModelRegistry::read_global(cx);
1105 let model = models
1106 .available_models(cx)
1107 .find(|model| model.id() == model_id)
1108 .unwrap();
1109
1110 let provider = models.provider(&model.provider_id()).unwrap();
1111 let authenticated = provider.authenticate(cx);
1112
1113 cx.spawn(async move |_cx| {
1114 authenticated.await.unwrap();
1115 model
1116 })
1117 }
1118 })
1119 .await;
1120
1121 let project_context = Rc::new(RefCell::new(ProjectContext::default()));
1122 let context_server_registry =
1123 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1124 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1125 let thread = cx.new(|cx| {
1126 Thread::new(
1127 project,
1128 project_context.clone(),
1129 context_server_registry,
1130 action_log,
1131 templates,
1132 model.clone(),
1133 cx,
1134 )
1135 });
1136 ThreadTest {
1137 model,
1138 thread,
1139 project_context,
1140 fs,
1141 }
1142}
1143
// Runs via `ctor` in test builds and initializes `env_logger`, but only when
// the `RUST_LOG` environment variable is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
1151
1152fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1153 let fs = fs.clone();
1154 cx.spawn({
1155 async move |cx| {
1156 let mut new_settings_content_rx = settings::watch_config_file(
1157 cx.background_executor(),
1158 fs,
1159 paths::settings_file().clone(),
1160 );
1161
1162 while let Some(new_settings_content) = new_settings_content_rx.next().await {
1163 cx.update(|cx| {
1164 SettingsStore::update_global(cx, |settings, cx| {
1165 settings.set_user_settings(&new_settings_content, cx)
1166 })
1167 })
1168 .ok();
1169 }
1170 }
1171 })
1172 .detach();
1173}