1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use agent_settings::AgentProfileId;
6use anyhow::Result;
7use client::{Client, UserStore};
8use fs::{FakeFs, Fs};
9use futures::channel::mpsc::UnboundedReceiver;
10use gpui::{
11 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
12};
13use indoc::indoc;
14use language_model::{
15 LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
16 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
17 Role, StopReason, fake_provider::FakeLanguageModel,
18};
19use project::Project;
20use prompt_store::ProjectContext;
21use reqwest_client::ReqwestClient;
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use serde_json::json;
25use settings::SettingsStore;
26use smol::stream::StreamExt;
27use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
28use util::path;
29
30mod test_tools;
31use test_tools::*;
32
33#[gpui::test]
34#[ignore = "can't run on CI yet"]
35async fn test_echo(cx: &mut TestAppContext) {
36 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
37
38 let events = thread
39 .update(cx, |thread, cx| {
40 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
41 })
42 .collect()
43 .await;
44 thread.update(cx, |thread, _cx| {
45 assert_eq!(
46 thread.last_message().unwrap().to_markdown(),
47 indoc! {"
48 ## Assistant
49
50 Hello
51 "}
52 )
53 });
54 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
55}
56
57#[gpui::test]
58#[ignore = "can't run on CI yet"]
59async fn test_thinking(cx: &mut TestAppContext) {
60 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
61
62 let events = thread
63 .update(cx, |thread, cx| {
64 thread.send(
65 UserMessageId::new(),
66 [indoc! {"
67 Testing:
68
69 Generate a thinking step where you just think the word 'Think',
70 and have your final answer be 'Hello'
71 "}],
72 cx,
73 )
74 })
75 .collect()
76 .await;
77 thread.update(cx, |thread, _cx| {
78 assert_eq!(
79 thread.last_message().unwrap().to_markdown(),
80 indoc! {"
81 ## Assistant
82
83 <think>Think</think>
84 Hello
85 "}
86 )
87 });
88 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
89}
90
91#[gpui::test]
92async fn test_system_prompt(cx: &mut TestAppContext) {
93 let ThreadTest {
94 model,
95 thread,
96 project_context,
97 ..
98 } = setup(cx, TestModel::Fake).await;
99 let fake_model = model.as_fake();
100
101 project_context.borrow_mut().shell = "test-shell".into();
102 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
103 thread.update(cx, |thread, cx| {
104 thread.send(UserMessageId::new(), ["abc"], cx)
105 });
106 cx.run_until_parked();
107 let mut pending_completions = fake_model.pending_completions();
108 assert_eq!(
109 pending_completions.len(),
110 1,
111 "unexpected pending completions: {:?}",
112 pending_completions
113 );
114
115 let pending_completion = pending_completions.pop().unwrap();
116 assert_eq!(pending_completion.messages[0].role, Role::System);
117
118 let system_message = &pending_completion.messages[0];
119 let system_prompt = system_message.content[0].to_str().unwrap();
120 assert!(
121 system_prompt.contains("test-shell"),
122 "unexpected system message: {:?}",
123 system_message
124 );
125 assert!(
126 system_prompt.contains("## Fixing Diagnostics"),
127 "unexpected system message: {:?}",
128 system_message
129 );
130}
131
132#[gpui::test]
133#[ignore = "can't run on CI yet"]
134async fn test_basic_tool_calls(cx: &mut TestAppContext) {
135 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
136
137 // Test a tool call that's likely to complete *before* streaming stops.
138 let events = thread
139 .update(cx, |thread, cx| {
140 thread.add_tool(EchoTool);
141 thread.send(
142 UserMessageId::new(),
143 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
144 cx,
145 )
146 })
147 .collect()
148 .await;
149 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
150
151 // Test a tool calls that's likely to complete *after* streaming stops.
152 let events = thread
153 .update(cx, |thread, cx| {
154 thread.remove_tool(&AgentTool::name(&EchoTool));
155 thread.add_tool(DelayTool);
156 thread.send(
157 UserMessageId::new(),
158 [
159 "Now call the delay tool with 200ms.",
160 "When the timer goes off, then you echo the output of the tool.",
161 ],
162 cx,
163 )
164 })
165 .collect()
166 .await;
167 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
168 thread.update(cx, |thread, _cx| {
169 assert!(
170 thread
171 .last_message()
172 .unwrap()
173 .as_agent_message()
174 .unwrap()
175 .content
176 .iter()
177 .any(|content| {
178 if let AgentMessageContent::Text(text) = content {
179 text.contains("Ding")
180 } else {
181 false
182 }
183 }),
184 "{}",
185 thread.to_markdown()
186 );
187 });
188}
189
190#[gpui::test]
191#[ignore = "can't run on CI yet"]
192async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
193 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
194
195 // Test a tool call that's likely to complete *before* streaming stops.
196 let mut events = thread.update(cx, |thread, cx| {
197 thread.add_tool(WordListTool);
198 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
199 });
200
201 let mut saw_partial_tool_use = false;
202 while let Some(event) = events.next().await {
203 if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
204 thread.update(cx, |thread, _cx| {
205 // Look for a tool use in the thread's last message
206 let message = thread.last_message().unwrap();
207 let agent_message = message.as_agent_message().unwrap();
208 let last_content = agent_message.content.last().unwrap();
209 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
210 assert_eq!(last_tool_use.name.as_ref(), "word_list");
211 if tool_call.status == acp::ToolCallStatus::Pending {
212 if !last_tool_use.is_input_complete
213 && last_tool_use.input.get("g").is_none()
214 {
215 saw_partial_tool_use = true;
216 }
217 } else {
218 last_tool_use
219 .input
220 .get("a")
221 .expect("'a' has streamed because input is now complete");
222 last_tool_use
223 .input
224 .get("g")
225 .expect("'g' has streamed because input is now complete");
226 }
227 } else {
228 panic!("last content should be a tool use");
229 }
230 });
231 }
232 }
233
234 assert!(
235 saw_partial_tool_use,
236 "should see at least one partially streamed tool use in the history"
237 );
238}
239
240#[gpui::test]
241async fn test_tool_authorization(cx: &mut TestAppContext) {
242 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
243 let fake_model = model.as_fake();
244
245 let mut events = thread.update(cx, |thread, cx| {
246 thread.add_tool(ToolRequiringPermission);
247 thread.send(UserMessageId::new(), ["abc"], cx)
248 });
249 cx.run_until_parked();
250 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
251 LanguageModelToolUse {
252 id: "tool_id_1".into(),
253 name: ToolRequiringPermission.name().into(),
254 raw_input: "{}".into(),
255 input: json!({}),
256 is_input_complete: true,
257 },
258 ));
259 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
260 LanguageModelToolUse {
261 id: "tool_id_2".into(),
262 name: ToolRequiringPermission.name().into(),
263 raw_input: "{}".into(),
264 input: json!({}),
265 is_input_complete: true,
266 },
267 ));
268 fake_model.end_last_completion_stream();
269 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
270 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
271
272 // Approve the first
273 tool_call_auth_1
274 .response
275 .send(tool_call_auth_1.options[1].id.clone())
276 .unwrap();
277 cx.run_until_parked();
278
279 // Reject the second
280 tool_call_auth_2
281 .response
282 .send(tool_call_auth_1.options[2].id.clone())
283 .unwrap();
284 cx.run_until_parked();
285
286 let completion = fake_model.pending_completions().pop().unwrap();
287 let message = completion.messages.last().unwrap();
288 assert_eq!(
289 message.content,
290 vec![
291 language_model::MessageContent::ToolResult(LanguageModelToolResult {
292 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
293 tool_name: ToolRequiringPermission.name().into(),
294 is_error: false,
295 content: "Allowed".into(),
296 output: Some("Allowed".into())
297 }),
298 language_model::MessageContent::ToolResult(LanguageModelToolResult {
299 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
300 tool_name: ToolRequiringPermission.name().into(),
301 is_error: true,
302 content: "Permission to run tool denied by user".into(),
303 output: None
304 })
305 ]
306 );
307
308 // Simulate yet another tool call.
309 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
310 LanguageModelToolUse {
311 id: "tool_id_3".into(),
312 name: ToolRequiringPermission.name().into(),
313 raw_input: "{}".into(),
314 input: json!({}),
315 is_input_complete: true,
316 },
317 ));
318 fake_model.end_last_completion_stream();
319
320 // Respond by always allowing tools.
321 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
322 tool_call_auth_3
323 .response
324 .send(tool_call_auth_3.options[0].id.clone())
325 .unwrap();
326 cx.run_until_parked();
327 let completion = fake_model.pending_completions().pop().unwrap();
328 let message = completion.messages.last().unwrap();
329 assert_eq!(
330 message.content,
331 vec![language_model::MessageContent::ToolResult(
332 LanguageModelToolResult {
333 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
334 tool_name: ToolRequiringPermission.name().into(),
335 is_error: false,
336 content: "Allowed".into(),
337 output: Some("Allowed".into())
338 }
339 )]
340 );
341
342 // Simulate a final tool call, ensuring we don't trigger authorization.
343 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
344 LanguageModelToolUse {
345 id: "tool_id_4".into(),
346 name: ToolRequiringPermission.name().into(),
347 raw_input: "{}".into(),
348 input: json!({}),
349 is_input_complete: true,
350 },
351 ));
352 fake_model.end_last_completion_stream();
353 cx.run_until_parked();
354 let completion = fake_model.pending_completions().pop().unwrap();
355 let message = completion.messages.last().unwrap();
356 assert_eq!(
357 message.content,
358 vec![language_model::MessageContent::ToolResult(
359 LanguageModelToolResult {
360 tool_use_id: "tool_id_4".into(),
361 tool_name: ToolRequiringPermission.name().into(),
362 is_error: false,
363 content: "Allowed".into(),
364 output: Some("Allowed".into())
365 }
366 )]
367 );
368}
369
370#[gpui::test]
371async fn test_tool_hallucination(cx: &mut TestAppContext) {
372 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
373 let fake_model = model.as_fake();
374
375 let mut events = thread.update(cx, |thread, cx| {
376 thread.send(UserMessageId::new(), ["abc"], cx)
377 });
378 cx.run_until_parked();
379 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
380 LanguageModelToolUse {
381 id: "tool_id_1".into(),
382 name: "nonexistent_tool".into(),
383 raw_input: "{}".into(),
384 input: json!({}),
385 is_input_complete: true,
386 },
387 ));
388 fake_model.end_last_completion_stream();
389
390 let tool_call = expect_tool_call(&mut events).await;
391 assert_eq!(tool_call.title, "nonexistent_tool");
392 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
393 let update = expect_tool_call_update_fields(&mut events).await;
394 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
395}
396
397#[gpui::test]
398async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
399 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
400 let fake_model = model.as_fake();
401
402 let events = thread.update(cx, |thread, cx| {
403 thread.add_tool(EchoTool);
404 thread.send(UserMessageId::new(), ["abc"], cx)
405 });
406 cx.run_until_parked();
407 let tool_use = LanguageModelToolUse {
408 id: "tool_id_1".into(),
409 name: EchoTool.name().into(),
410 raw_input: "{}".into(),
411 input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
412 is_input_complete: true,
413 };
414 fake_model
415 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
416 fake_model.end_last_completion_stream();
417
418 cx.run_until_parked();
419 let completion = fake_model.pending_completions().pop().unwrap();
420 let tool_result = LanguageModelToolResult {
421 tool_use_id: "tool_id_1".into(),
422 tool_name: EchoTool.name().into(),
423 is_error: false,
424 content: "def".into(),
425 output: Some("def".into()),
426 };
427 assert_eq!(
428 completion.messages[1..],
429 vec![
430 LanguageModelRequestMessage {
431 role: Role::User,
432 content: vec!["abc".into()],
433 cache: false
434 },
435 LanguageModelRequestMessage {
436 role: Role::Assistant,
437 content: vec![MessageContent::ToolUse(tool_use.clone())],
438 cache: false
439 },
440 LanguageModelRequestMessage {
441 role: Role::User,
442 content: vec![MessageContent::ToolResult(tool_result.clone())],
443 cache: false
444 },
445 ]
446 );
447
448 // Simulate reaching tool use limit.
449 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
450 cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
451 ));
452 fake_model.end_last_completion_stream();
453 let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
454 assert!(
455 last_event
456 .unwrap_err()
457 .is::<language_model::ToolUseLimitReachedError>()
458 );
459
460 let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
461 cx.run_until_parked();
462 let completion = fake_model.pending_completions().pop().unwrap();
463 assert_eq!(
464 completion.messages[1..],
465 vec![
466 LanguageModelRequestMessage {
467 role: Role::User,
468 content: vec!["abc".into()],
469 cache: false
470 },
471 LanguageModelRequestMessage {
472 role: Role::Assistant,
473 content: vec![MessageContent::ToolUse(tool_use)],
474 cache: false
475 },
476 LanguageModelRequestMessage {
477 role: Role::User,
478 content: vec![MessageContent::ToolResult(tool_result)],
479 cache: false
480 },
481 LanguageModelRequestMessage {
482 role: Role::User,
483 content: vec!["Continue where you left off".into()],
484 cache: false
485 }
486 ]
487 );
488
489 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
490 fake_model.end_last_completion_stream();
491 events.collect::<Vec<_>>().await;
492 thread.read_with(cx, |thread, _cx| {
493 assert_eq!(
494 thread.last_message().unwrap().to_markdown(),
495 indoc! {"
496 ## Assistant
497
498 Done
499 "}
500 )
501 });
502
503 // Ensure we error if calling resume when tool use limit was *not* reached.
504 let error = thread
505 .update(cx, |thread, cx| thread.resume(cx))
506 .unwrap_err();
507 assert_eq!(
508 error.to_string(),
509 "can only resume after tool use limit is reached"
510 )
511}
512
513#[gpui::test]
514async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
515 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
516 let fake_model = model.as_fake();
517
518 let events = thread.update(cx, |thread, cx| {
519 thread.add_tool(EchoTool);
520 thread.send(UserMessageId::new(), ["abc"], cx)
521 });
522 cx.run_until_parked();
523
524 let tool_use = LanguageModelToolUse {
525 id: "tool_id_1".into(),
526 name: EchoTool.name().into(),
527 raw_input: "{}".into(),
528 input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
529 is_input_complete: true,
530 };
531 let tool_result = LanguageModelToolResult {
532 tool_use_id: "tool_id_1".into(),
533 tool_name: EchoTool.name().into(),
534 is_error: false,
535 content: "def".into(),
536 output: Some("def".into()),
537 };
538 fake_model
539 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
540 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
541 cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
542 ));
543 fake_model.end_last_completion_stream();
544 let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
545 assert!(
546 last_event
547 .unwrap_err()
548 .is::<language_model::ToolUseLimitReachedError>()
549 );
550
551 thread.update(cx, |thread, cx| {
552 thread.send(UserMessageId::new(), vec!["ghi"], cx)
553 });
554 cx.run_until_parked();
555 let completion = fake_model.pending_completions().pop().unwrap();
556 assert_eq!(
557 completion.messages[1..],
558 vec![
559 LanguageModelRequestMessage {
560 role: Role::User,
561 content: vec!["abc".into()],
562 cache: false
563 },
564 LanguageModelRequestMessage {
565 role: Role::Assistant,
566 content: vec![MessageContent::ToolUse(tool_use)],
567 cache: false
568 },
569 LanguageModelRequestMessage {
570 role: Role::User,
571 content: vec![MessageContent::ToolResult(tool_result)],
572 cache: false
573 },
574 LanguageModelRequestMessage {
575 role: Role::User,
576 content: vec!["ghi".into()],
577 cache: false
578 }
579 ]
580 );
581}
582
583async fn expect_tool_call(
584 events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
585) -> acp::ToolCall {
586 let event = events
587 .next()
588 .await
589 .expect("no tool call authorization event received")
590 .unwrap();
591 match event {
592 AgentResponseEvent::ToolCall(tool_call) => return tool_call,
593 event => {
594 panic!("Unexpected event {event:?}");
595 }
596 }
597}
598
599async fn expect_tool_call_update_fields(
600 events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
601) -> acp::ToolCallUpdate {
602 let event = events
603 .next()
604 .await
605 .expect("no tool call authorization event received")
606 .unwrap();
607 match event {
608 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
609 return update;
610 }
611 event => {
612 panic!("Unexpected event {event:?}");
613 }
614 }
615}
616
617async fn next_tool_call_authorization(
618 events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
619) -> ToolCallAuthorization {
620 loop {
621 let event = events
622 .next()
623 .await
624 .expect("no tool call authorization event received")
625 .unwrap();
626 if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
627 let permission_kinds = tool_call_authorization
628 .options
629 .iter()
630 .map(|o| o.kind)
631 .collect::<Vec<_>>();
632 assert_eq!(
633 permission_kinds,
634 vec![
635 acp::PermissionOptionKind::AllowAlways,
636 acp::PermissionOptionKind::AllowOnce,
637 acp::PermissionOptionKind::RejectOnce,
638 ]
639 );
640 return tool_call_authorization;
641 }
642 }
643}
644
645#[gpui::test]
646#[ignore = "can't run on CI yet"]
647async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
648 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
649
650 // Test concurrent tool calls with different delay times
651 let events = thread
652 .update(cx, |thread, cx| {
653 thread.add_tool(DelayTool);
654 thread.send(
655 UserMessageId::new(),
656 [
657 "Call the delay tool twice in the same message.",
658 "Once with 100ms. Once with 300ms.",
659 "When both timers are complete, describe the outputs.",
660 ],
661 cx,
662 )
663 })
664 .collect()
665 .await;
666
667 let stop_reasons = stop_events(events);
668 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
669
670 thread.update(cx, |thread, _cx| {
671 let last_message = thread.last_message().unwrap();
672 let agent_message = last_message.as_agent_message().unwrap();
673 let text = agent_message
674 .content
675 .iter()
676 .filter_map(|content| {
677 if let AgentMessageContent::Text(text) = content {
678 Some(text.as_str())
679 } else {
680 None
681 }
682 })
683 .collect::<String>();
684
685 assert!(text.contains("Ding"));
686 });
687}
688
689#[gpui::test]
690async fn test_profiles(cx: &mut TestAppContext) {
691 let ThreadTest {
692 model, thread, fs, ..
693 } = setup(cx, TestModel::Fake).await;
694 let fake_model = model.as_fake();
695
696 thread.update(cx, |thread, _cx| {
697 thread.add_tool(DelayTool);
698 thread.add_tool(EchoTool);
699 thread.add_tool(InfiniteTool);
700 });
701
702 // Override profiles and wait for settings to be loaded.
703 fs.insert_file(
704 paths::settings_file(),
705 json!({
706 "agent": {
707 "profiles": {
708 "test-1": {
709 "name": "Test Profile 1",
710 "tools": {
711 EchoTool.name(): true,
712 DelayTool.name(): true,
713 }
714 },
715 "test-2": {
716 "name": "Test Profile 2",
717 "tools": {
718 InfiniteTool.name(): true,
719 }
720 }
721 }
722 }
723 })
724 .to_string()
725 .into_bytes(),
726 )
727 .await;
728 cx.run_until_parked();
729
730 // Test that test-1 profile (default) has echo and delay tools
731 thread.update(cx, |thread, cx| {
732 thread.set_profile(AgentProfileId("test-1".into()));
733 thread.send(UserMessageId::new(), ["test"], cx);
734 });
735 cx.run_until_parked();
736
737 let mut pending_completions = fake_model.pending_completions();
738 assert_eq!(pending_completions.len(), 1);
739 let completion = pending_completions.pop().unwrap();
740 let tool_names: Vec<String> = completion
741 .tools
742 .iter()
743 .map(|tool| tool.name.clone())
744 .collect();
745 assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
746 fake_model.end_last_completion_stream();
747
748 // Switch to test-2 profile, and verify that it has only the infinite tool.
749 thread.update(cx, |thread, cx| {
750 thread.set_profile(AgentProfileId("test-2".into()));
751 thread.send(UserMessageId::new(), ["test2"], cx)
752 });
753 cx.run_until_parked();
754 let mut pending_completions = fake_model.pending_completions();
755 assert_eq!(pending_completions.len(), 1);
756 let completion = pending_completions.pop().unwrap();
757 let tool_names: Vec<String> = completion
758 .tools
759 .iter()
760 .map(|tool| tool.name.clone())
761 .collect();
762 assert_eq!(tool_names, vec![InfiniteTool.name()]);
763}
764
765#[gpui::test]
766#[ignore = "can't run on CI yet"]
767async fn test_cancellation(cx: &mut TestAppContext) {
768 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
769
770 let mut events = thread.update(cx, |thread, cx| {
771 thread.add_tool(InfiniteTool);
772 thread.add_tool(EchoTool);
773 thread.send(
774 UserMessageId::new(),
775 ["Call the echo tool, then call the infinite tool, then explain their output"],
776 cx,
777 )
778 });
779
780 // Wait until both tools are called.
781 let mut expected_tools = vec!["Echo", "Infinite Tool"];
782 let mut echo_id = None;
783 let mut echo_completed = false;
784 while let Some(event) = events.next().await {
785 match event.unwrap() {
786 AgentResponseEvent::ToolCall(tool_call) => {
787 assert_eq!(tool_call.title, expected_tools.remove(0));
788 if tool_call.title == "Echo" {
789 echo_id = Some(tool_call.id);
790 }
791 }
792 AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
793 acp::ToolCallUpdate {
794 id,
795 fields:
796 acp::ToolCallUpdateFields {
797 status: Some(acp::ToolCallStatus::Completed),
798 ..
799 },
800 },
801 )) if Some(&id) == echo_id.as_ref() => {
802 echo_completed = true;
803 }
804 _ => {}
805 }
806
807 if expected_tools.is_empty() && echo_completed {
808 break;
809 }
810 }
811
812 // Cancel the current send and ensure that the event stream is closed, even
813 // if one of the tools is still running.
814 thread.update(cx, |thread, _cx| thread.cancel());
815 events.collect::<Vec<_>>().await;
816
817 // Ensure we can still send a new message after cancellation.
818 let events = thread
819 .update(cx, |thread, cx| {
820 thread.send(
821 UserMessageId::new(),
822 ["Testing: reply with 'Hello' then stop."],
823 cx,
824 )
825 })
826 .collect::<Vec<_>>()
827 .await;
828 thread.update(cx, |thread, _cx| {
829 let message = thread.last_message().unwrap();
830 let agent_message = message.as_agent_message().unwrap();
831 assert_eq!(
832 agent_message.content,
833 vec![AgentMessageContent::Text("Hello".to_string())]
834 );
835 });
836 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
837}
838
839#[gpui::test]
840async fn test_refusal(cx: &mut TestAppContext) {
841 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
842 let fake_model = model.as_fake();
843
844 let events = thread.update(cx, |thread, cx| {
845 thread.send(UserMessageId::new(), ["Hello"], cx)
846 });
847 cx.run_until_parked();
848 thread.read_with(cx, |thread, _| {
849 assert_eq!(
850 thread.to_markdown(),
851 indoc! {"
852 ## User
853
854 Hello
855 "}
856 );
857 });
858
859 fake_model.send_last_completion_stream_text_chunk("Hey!");
860 cx.run_until_parked();
861 thread.read_with(cx, |thread, _| {
862 assert_eq!(
863 thread.to_markdown(),
864 indoc! {"
865 ## User
866
867 Hello
868
869 ## Assistant
870
871 Hey!
872 "}
873 );
874 });
875
876 // If the model refuses to continue, the thread should remove all the messages after the last user message.
877 fake_model
878 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
879 let events = events.collect::<Vec<_>>().await;
880 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
881 thread.read_with(cx, |thread, _| {
882 assert_eq!(thread.to_markdown(), "");
883 });
884}
885
886#[gpui::test]
887async fn test_truncate(cx: &mut TestAppContext) {
888 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
889 let fake_model = model.as_fake();
890
891 let message_id = UserMessageId::new();
892 thread.update(cx, |thread, cx| {
893 thread.send(message_id.clone(), ["Hello"], cx)
894 });
895 cx.run_until_parked();
896 thread.read_with(cx, |thread, _| {
897 assert_eq!(
898 thread.to_markdown(),
899 indoc! {"
900 ## User
901
902 Hello
903 "}
904 );
905 });
906
907 fake_model.send_last_completion_stream_text_chunk("Hey!");
908 cx.run_until_parked();
909 thread.read_with(cx, |thread, _| {
910 assert_eq!(
911 thread.to_markdown(),
912 indoc! {"
913 ## User
914
915 Hello
916
917 ## Assistant
918
919 Hey!
920 "}
921 );
922 });
923
924 thread
925 .update(cx, |thread, _cx| thread.truncate(message_id))
926 .unwrap();
927 cx.run_until_parked();
928 thread.read_with(cx, |thread, _| {
929 assert_eq!(thread.to_markdown(), "");
930 });
931
932 // Ensure we can still send a new message after truncation.
933 thread.update(cx, |thread, cx| {
934 thread.send(UserMessageId::new(), ["Hi"], cx)
935 });
936 thread.update(cx, |thread, _cx| {
937 assert_eq!(
938 thread.to_markdown(),
939 indoc! {"
940 ## User
941
942 Hi
943 "}
944 );
945 });
946 cx.run_until_parked();
947 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
948 cx.run_until_parked();
949 thread.read_with(cx, |thread, _| {
950 assert_eq!(
951 thread.to_markdown(),
952 indoc! {"
953 ## User
954
955 Hi
956
957 ## Assistant
958
959 Ahoy!
960 "}
961 );
962 });
963}
964
965#[gpui::test]
966async fn test_agent_connection(cx: &mut TestAppContext) {
967 cx.update(settings::init);
968 let templates = Templates::new();
969
970 // Initialize language model system with test provider
971 cx.update(|cx| {
972 gpui_tokio::init(cx);
973 client::init_settings(cx);
974
975 let http_client = FakeHttpClient::with_404_response();
976 let clock = Arc::new(clock::FakeSystemClock::new());
977 let client = Client::new(clock, http_client, cx);
978 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
979 language_model::init(client.clone(), cx);
980 language_models::init(user_store.clone(), client.clone(), cx);
981 Project::init_settings(cx);
982 LanguageModelRegistry::test(cx);
983 agent_settings::init(cx);
984 });
985 cx.executor().forbid_parking();
986
987 // Create a project for new_thread
988 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
989 fake_fs.insert_tree(path!("/test"), json!({})).await;
990 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
991 let cwd = Path::new("/test");
992
993 // Create agent and connection
994 let agent = NativeAgent::new(
995 project.clone(),
996 templates.clone(),
997 None,
998 fake_fs.clone(),
999 &mut cx.to_async(),
1000 )
1001 .await
1002 .unwrap();
1003 let connection = NativeAgentConnection(agent.clone());
1004
1005 // Test model_selector returns Some
1006 let selector_opt = connection.model_selector();
1007 assert!(
1008 selector_opt.is_some(),
1009 "agent2 should always support ModelSelector"
1010 );
1011 let selector = selector_opt.unwrap();
1012
1013 // Test list_models
1014 let listed_models = cx
1015 .update(|cx| selector.list_models(cx))
1016 .await
1017 .expect("list_models should succeed");
1018 let AgentModelList::Grouped(listed_models) = listed_models else {
1019 panic!("Unexpected model list type");
1020 };
1021 assert!(!listed_models.is_empty(), "should have at least one model");
1022 assert_eq!(
1023 listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
1024 "fake/fake"
1025 );
1026
1027 // Create a thread using new_thread
1028 let connection_rc = Rc::new(connection.clone());
1029 let acp_thread = cx
1030 .update(|cx| connection_rc.new_thread(project, cwd, cx))
1031 .await
1032 .expect("new_thread should succeed");
1033
1034 // Get the session_id from the AcpThread
1035 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1036
1037 // Test selected_model returns the default
1038 let model = cx
1039 .update(|cx| selector.selected_model(&session_id, cx))
1040 .await
1041 .expect("selected_model should succeed");
1042 let model = cx
1043 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1044 .unwrap();
1045 let model = model.as_fake();
1046 assert_eq!(model.id().0, "fake", "should return default model");
1047
1048 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1049 cx.run_until_parked();
1050 model.send_last_completion_stream_text_chunk("def");
1051 cx.run_until_parked();
1052 acp_thread.read_with(cx, |thread, cx| {
1053 assert_eq!(
1054 thread.to_markdown(cx),
1055 indoc! {"
1056 ## User
1057
1058 abc
1059
1060 ## Assistant
1061
1062 def
1063
1064 "}
1065 )
1066 });
1067
1068 // Test cancel
1069 cx.update(|cx| connection.cancel(&session_id, cx));
1070 request.await.expect("prompt should fail gracefully");
1071
1072 // Ensure that dropping the ACP thread causes the native thread to be
1073 // dropped as well.
1074 cx.update(|_| drop(acp_thread));
1075 let result = cx
1076 .update(|cx| {
1077 connection.prompt(
1078 Some(acp_thread::UserMessageId::new()),
1079 acp::PromptRequest {
1080 session_id: session_id.clone(),
1081 prompt: vec!["ghi".into()],
1082 },
1083 cx,
1084 )
1085 })
1086 .await;
1087 assert_eq!(
1088 result.as_ref().unwrap_err().to_string(),
1089 "Session not found",
1090 "unexpected result: {:?}",
1091 result
1092 );
1093}
1094
1095#[gpui::test]
1096async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1097 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1098 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1099 let fake_model = model.as_fake();
1100
1101 let mut events = thread.update(cx, |thread, cx| {
1102 thread.send(UserMessageId::new(), ["Think"], cx)
1103 });
1104 cx.run_until_parked();
1105
1106 // Simulate streaming partial input.
1107 let input = json!({});
1108 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1109 LanguageModelToolUse {
1110 id: "1".into(),
1111 name: ThinkingTool.name().into(),
1112 raw_input: input.to_string(),
1113 input,
1114 is_input_complete: false,
1115 },
1116 ));
1117
1118 // Input streaming completed
1119 let input = json!({ "content": "Thinking hard!" });
1120 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1121 LanguageModelToolUse {
1122 id: "1".into(),
1123 name: "thinking".into(),
1124 raw_input: input.to_string(),
1125 input,
1126 is_input_complete: true,
1127 },
1128 ));
1129 fake_model.end_last_completion_stream();
1130 cx.run_until_parked();
1131
1132 let tool_call = expect_tool_call(&mut events).await;
1133 assert_eq!(
1134 tool_call,
1135 acp::ToolCall {
1136 id: acp::ToolCallId("1".into()),
1137 title: "Thinking".into(),
1138 kind: acp::ToolKind::Think,
1139 status: acp::ToolCallStatus::Pending,
1140 content: vec![],
1141 locations: vec![],
1142 raw_input: Some(json!({})),
1143 raw_output: None,
1144 }
1145 );
1146 let update = expect_tool_call_update_fields(&mut events).await;
1147 assert_eq!(
1148 update,
1149 acp::ToolCallUpdate {
1150 id: acp::ToolCallId("1".into()),
1151 fields: acp::ToolCallUpdateFields {
1152 title: Some("Thinking".into()),
1153 kind: Some(acp::ToolKind::Think),
1154 raw_input: Some(json!({ "content": "Thinking hard!" })),
1155 ..Default::default()
1156 },
1157 }
1158 );
1159 let update = expect_tool_call_update_fields(&mut events).await;
1160 assert_eq!(
1161 update,
1162 acp::ToolCallUpdate {
1163 id: acp::ToolCallId("1".into()),
1164 fields: acp::ToolCallUpdateFields {
1165 status: Some(acp::ToolCallStatus::InProgress),
1166 ..Default::default()
1167 },
1168 }
1169 );
1170 let update = expect_tool_call_update_fields(&mut events).await;
1171 assert_eq!(
1172 update,
1173 acp::ToolCallUpdate {
1174 id: acp::ToolCallId("1".into()),
1175 fields: acp::ToolCallUpdateFields {
1176 content: Some(vec!["Thinking hard!".into()]),
1177 ..Default::default()
1178 },
1179 }
1180 );
1181 let update = expect_tool_call_update_fields(&mut events).await;
1182 assert_eq!(
1183 update,
1184 acp::ToolCallUpdate {
1185 id: acp::ToolCallId("1".into()),
1186 fields: acp::ToolCallUpdateFields {
1187 status: Some(acp::ToolCallStatus::Completed),
1188 raw_output: Some("Finished thinking.".into()),
1189 ..Default::default()
1190 },
1191 }
1192 );
1193}
1194
1195/// Filters out the stop events for asserting against in tests
1196fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
1197 result_events
1198 .into_iter()
1199 .filter_map(|event| match event.unwrap() {
1200 AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
1201 _ => None,
1202 })
1203 .collect()
1204}
1205
/// Handles produced by `setup` for a single test; tests destructure the
/// fields they need.
struct ThreadTest {
    /// The language model selected in `setup`; a clone of it is passed to `Thread::new`.
    model: Arc<dyn LanguageModel>,
    /// The `Thread` entity created by `setup`.
    thread: Entity<Thread>,
    /// Project context; `setup` passes a clone of this `Rc` to `Thread::new`.
    project_context: Rc<RefCell<ProjectContext>>,
    /// The fake filesystem created in `setup`, holding the settings file and
    /// the `/test` tree.
    fs: Arc<FakeFs>,
}
1212
/// Selects which language model `setup` gives to the thread under test.
enum TestModel {
    /// Resolved through the model registry using the id `claude-sonnet-4-latest`.
    Sonnet4,
    /// Resolved through the model registry using the id
    /// `claude-sonnet-4-thinking-latest`.
    Sonnet4Thinking,
    /// A `FakeLanguageModel::default()` constructed directly in `setup`; it
    /// has no registry id (calling `id` on it hits `unreachable!()`).
    Fake,
}
1218
1219impl TestModel {
1220 fn id(&self) -> LanguageModelId {
1221 match self {
1222 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1223 TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1224 TestModel::Fake => unreachable!(),
1225 }
1226 }
1227}
1228
/// Builds the fixture used by the tests in this file.
///
/// Writes a settings file onto a fresh `FakeFs`, initializes the global
/// settings, client and language-model state, creates a test project rooted
/// at `/test`, resolves the requested `model`, and constructs a `Thread`
/// wired to all of those.
///
/// # Panics
///
/// Panics if any of the `unwrap` calls below fail: creating the settings
/// directory, building the HTTP client, finding a registry model whose id
/// matches `model.id()`, finding that model's provider, or authenticating
/// with that provider.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably parking is allowed because non-fake models
    // perform real network I/O — confirm.
    cx.executor().allow_parking();

    // Write a user settings file that makes `test-profile` the default agent
    // profile and enables the listed test tools in it.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // Install a `ReqwestClient` as the app's HTTP client.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        // Apply the settings file written above (and any later changes to it)
        // as user settings.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model. The fake model is available immediately; any other
    // model is looked up in the registry by id and only returned once its
    // provider's `authenticate` task has completed successfully.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // The thread receives clones of `project_context` and `model`, so the
    // test keeps its own handles to both through `ThreadTest`.
    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            model.clone(),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1327
/// Initializes `env_logger`, but only when `RUST_LOG` can be read from the
/// environment. Registered through `#[ctor::ctor]` in test builds.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_is_set {
        return;
    }
    env_logger::init();
}
1335
1336fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1337 let fs = fs.clone();
1338 cx.spawn({
1339 async move |cx| {
1340 let mut new_settings_content_rx = settings::watch_config_file(
1341 cx.background_executor(),
1342 fs,
1343 paths::settings_file().clone(),
1344 );
1345
1346 while let Some(new_settings_content) = new_settings_content_rx.next().await {
1347 cx.update(|cx| {
1348 SettingsStore::update_global(cx, |settings, cx| {
1349 settings.set_user_settings(&new_settings_content, cx)
1350 })
1351 })
1352 .ok();
1353 }
1354 }
1355 })
1356 .detach();
1357}