1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Scripted stand-in for a terminal handle used by the terminal tool tests.
///
/// Records whether it was killed, reports a fixed output snapshot, and exposes
/// a shared task that tests can wait on for the exit status.
///
/// NOTE: field order also fixes drop order (the exit sender is dropped before
/// the shared exit task); keep it stable.
struct FakeTerminalHandle {
    /// Set to `true` by `TerminalHandle::kill`; read by `was_killed`.
    killed: Arc<AtomicBool>,
    /// Value reported by `TerminalHandle::was_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// Sender half of the exit signal; `signal_exit` takes and fires it at most once.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Task resolving to the exit status; cloned for every `wait_for_exit` caller.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Output snapshot returned by `TerminalHandle::current_output`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned by `TerminalHandle::id`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
158struct FakeThreadEnvironment {
159 handle: Rc<FakeTerminalHandle>,
160}
161
162impl crate::ThreadEnvironment for FakeThreadEnvironment {
163 fn create_terminal(
164 &self,
165 _command: String,
166 _cwd: Option<std::path::PathBuf>,
167 _output_byte_limit: Option<u64>,
168 _cx: &mut AsyncApp,
169 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
170 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
171 }
172}
173
174/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
175struct MultiTerminalEnvironment {
176 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
177}
178
179impl MultiTerminalEnvironment {
180 fn new() -> Self {
181 Self {
182 handles: std::cell::RefCell::new(Vec::new()),
183 }
184 }
185
186 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
187 self.handles.borrow().clone()
188 }
189}
190
191impl crate::ThreadEnvironment for MultiTerminalEnvironment {
192 fn create_terminal(
193 &self,
194 _command: String,
195 _cwd: Option<std::path::PathBuf>,
196 _output_byte_limit: Option<u64>,
197 cx: &mut AsyncApp,
198 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
199 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
200 self.handles.borrow_mut().push(handle.clone());
201 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
202 }
203}
204
205fn always_allow_tools(cx: &mut TestAppContext) {
206 cx.update(|cx| {
207 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
208 settings.always_allow_tool_actions = true;
209 agent_settings::AgentSettings::override_global(settings, cx);
210 });
211}
212
213#[gpui::test]
214async fn test_echo(cx: &mut TestAppContext) {
215 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
216 let fake_model = model.as_fake();
217
218 let events = thread
219 .update(cx, |thread, cx| {
220 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
221 })
222 .unwrap();
223 cx.run_until_parked();
224 fake_model.send_last_completion_stream_text_chunk("Hello");
225 fake_model
226 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
227 fake_model.end_last_completion_stream();
228
229 let events = events.collect().await;
230 thread.update(cx, |thread, _cx| {
231 assert_eq!(
232 thread.last_message().unwrap().to_markdown(),
233 indoc! {"
234 ## Assistant
235
236 Hello
237 "}
238 )
239 });
240 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
241}
242
243#[gpui::test]
244async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
245 init_test(cx);
246 always_allow_tools(cx);
247
248 let fs = FakeFs::new(cx.executor());
249 let project = Project::test(fs, [], cx).await;
250
251 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
252 let environment = Rc::new(FakeThreadEnvironment {
253 handle: handle.clone(),
254 });
255
256 #[allow(clippy::arc_with_non_send_sync)]
257 let tool = Arc::new(crate::TerminalTool::new(project, environment));
258 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
259
260 let task = cx.update(|cx| {
261 tool.run(
262 crate::TerminalToolInput {
263 command: "sleep 1000".to_string(),
264 cd: ".".to_string(),
265 timeout_ms: Some(5),
266 },
267 event_stream,
268 cx,
269 )
270 });
271
272 let update = rx.expect_update_fields().await;
273 assert!(
274 update.content.iter().any(|blocks| {
275 blocks
276 .iter()
277 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
278 }),
279 "expected tool call update to include terminal content"
280 );
281
282 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
283
284 let deadline = std::time::Instant::now() + Duration::from_millis(500);
285 loop {
286 if let Some(result) = task_future.as_mut().now_or_never() {
287 let result = result.expect("terminal tool task should complete");
288
289 assert!(
290 handle.was_killed(),
291 "expected terminal handle to be killed on timeout"
292 );
293 assert!(
294 result.contains("partial output"),
295 "expected result to include terminal output, got: {result}"
296 );
297 return;
298 }
299
300 if std::time::Instant::now() >= deadline {
301 panic!("timed out waiting for terminal tool task to complete");
302 }
303
304 cx.run_until_parked();
305 cx.background_executor.timer(Duration::from_millis(1)).await;
306 }
307}
308
309#[gpui::test]
310#[ignore]
311async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
312 init_test(cx);
313 always_allow_tools(cx);
314
315 let fs = FakeFs::new(cx.executor());
316 let project = Project::test(fs, [], cx).await;
317
318 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
319 let environment = Rc::new(FakeThreadEnvironment {
320 handle: handle.clone(),
321 });
322
323 #[allow(clippy::arc_with_non_send_sync)]
324 let tool = Arc::new(crate::TerminalTool::new(project, environment));
325 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
326
327 let _task = cx.update(|cx| {
328 tool.run(
329 crate::TerminalToolInput {
330 command: "sleep 1000".to_string(),
331 cd: ".".to_string(),
332 timeout_ms: None,
333 },
334 event_stream,
335 cx,
336 )
337 });
338
339 let update = rx.expect_update_fields().await;
340 assert!(
341 update.content.iter().any(|blocks| {
342 blocks
343 .iter()
344 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
345 }),
346 "expected tool call update to include terminal content"
347 );
348
349 cx.background_executor
350 .timer(Duration::from_millis(25))
351 .await;
352
353 assert!(
354 !handle.was_killed(),
355 "did not expect terminal handle to be killed without a timeout"
356 );
357}
358
359#[gpui::test]
360async fn test_thinking(cx: &mut TestAppContext) {
361 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
362 let fake_model = model.as_fake();
363
364 let events = thread
365 .update(cx, |thread, cx| {
366 thread.send(
367 UserMessageId::new(),
368 [indoc! {"
369 Testing:
370
371 Generate a thinking step where you just think the word 'Think',
372 and have your final answer be 'Hello'
373 "}],
374 cx,
375 )
376 })
377 .unwrap();
378 cx.run_until_parked();
379 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
380 text: "Think".to_string(),
381 signature: None,
382 });
383 fake_model.send_last_completion_stream_text_chunk("Hello");
384 fake_model
385 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
386 fake_model.end_last_completion_stream();
387
388 let events = events.collect().await;
389 thread.update(cx, |thread, _cx| {
390 assert_eq!(
391 thread.last_message().unwrap().to_markdown(),
392 indoc! {"
393 ## Assistant
394
395 <think>Think</think>
396 Hello
397 "}
398 )
399 });
400 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
401}
402
403#[gpui::test]
404async fn test_system_prompt(cx: &mut TestAppContext) {
405 let ThreadTest {
406 model,
407 thread,
408 project_context,
409 ..
410 } = setup(cx, TestModel::Fake).await;
411 let fake_model = model.as_fake();
412
413 project_context.update(cx, |project_context, _cx| {
414 project_context.shell = "test-shell".into()
415 });
416 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
417 thread
418 .update(cx, |thread, cx| {
419 thread.send(UserMessageId::new(), ["abc"], cx)
420 })
421 .unwrap();
422 cx.run_until_parked();
423 let mut pending_completions = fake_model.pending_completions();
424 assert_eq!(
425 pending_completions.len(),
426 1,
427 "unexpected pending completions: {:?}",
428 pending_completions
429 );
430
431 let pending_completion = pending_completions.pop().unwrap();
432 assert_eq!(pending_completion.messages[0].role, Role::System);
433
434 let system_message = &pending_completion.messages[0];
435 let system_prompt = system_message.content[0].to_str().unwrap();
436 assert!(
437 system_prompt.contains("test-shell"),
438 "unexpected system message: {:?}",
439 system_message
440 );
441 assert!(
442 system_prompt.contains("## Fixing Diagnostics"),
443 "unexpected system message: {:?}",
444 system_message
445 );
446}
447
448#[gpui::test]
449async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
450 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
451 let fake_model = model.as_fake();
452
453 thread
454 .update(cx, |thread, cx| {
455 thread.send(UserMessageId::new(), ["abc"], cx)
456 })
457 .unwrap();
458 cx.run_until_parked();
459 let mut pending_completions = fake_model.pending_completions();
460 assert_eq!(
461 pending_completions.len(),
462 1,
463 "unexpected pending completions: {:?}",
464 pending_completions
465 );
466
467 let pending_completion = pending_completions.pop().unwrap();
468 assert_eq!(pending_completion.messages[0].role, Role::System);
469
470 let system_message = &pending_completion.messages[0];
471 let system_prompt = system_message.content[0].to_str().unwrap();
472 assert!(
473 !system_prompt.contains("## Tool Use"),
474 "unexpected system message: {:?}",
475 system_message
476 );
477 assert!(
478 !system_prompt.contains("## Fixing Diagnostics"),
479 "unexpected system message: {:?}",
480 system_message
481 );
482}
483
484#[gpui::test]
485async fn test_prompt_caching(cx: &mut TestAppContext) {
486 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
487 let fake_model = model.as_fake();
488
489 // Send initial user message and verify it's cached
490 thread
491 .update(cx, |thread, cx| {
492 thread.send(UserMessageId::new(), ["Message 1"], cx)
493 })
494 .unwrap();
495 cx.run_until_parked();
496
497 let completion = fake_model.pending_completions().pop().unwrap();
498 assert_eq!(
499 completion.messages[1..],
500 vec![LanguageModelRequestMessage {
501 role: Role::User,
502 content: vec!["Message 1".into()],
503 cache: true,
504 reasoning_details: None,
505 }]
506 );
507 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
508 "Response to Message 1".into(),
509 ));
510 fake_model.end_last_completion_stream();
511 cx.run_until_parked();
512
513 // Send another user message and verify only the latest is cached
514 thread
515 .update(cx, |thread, cx| {
516 thread.send(UserMessageId::new(), ["Message 2"], cx)
517 })
518 .unwrap();
519 cx.run_until_parked();
520
521 let completion = fake_model.pending_completions().pop().unwrap();
522 assert_eq!(
523 completion.messages[1..],
524 vec![
525 LanguageModelRequestMessage {
526 role: Role::User,
527 content: vec!["Message 1".into()],
528 cache: false,
529 reasoning_details: None,
530 },
531 LanguageModelRequestMessage {
532 role: Role::Assistant,
533 content: vec!["Response to Message 1".into()],
534 cache: false,
535 reasoning_details: None,
536 },
537 LanguageModelRequestMessage {
538 role: Role::User,
539 content: vec!["Message 2".into()],
540 cache: true,
541 reasoning_details: None,
542 }
543 ]
544 );
545 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
546 "Response to Message 2".into(),
547 ));
548 fake_model.end_last_completion_stream();
549 cx.run_until_parked();
550
551 // Simulate a tool call and verify that the latest tool result is cached
552 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
553 thread
554 .update(cx, |thread, cx| {
555 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
556 })
557 .unwrap();
558 cx.run_until_parked();
559
560 let tool_use = LanguageModelToolUse {
561 id: "tool_1".into(),
562 name: EchoTool::name().into(),
563 raw_input: json!({"text": "test"}).to_string(),
564 input: json!({"text": "test"}),
565 is_input_complete: true,
566 thought_signature: None,
567 };
568 fake_model
569 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
570 fake_model.end_last_completion_stream();
571 cx.run_until_parked();
572
573 let completion = fake_model.pending_completions().pop().unwrap();
574 let tool_result = LanguageModelToolResult {
575 tool_use_id: "tool_1".into(),
576 tool_name: EchoTool::name().into(),
577 is_error: false,
578 content: "test".into(),
579 output: Some("test".into()),
580 };
581 assert_eq!(
582 completion.messages[1..],
583 vec![
584 LanguageModelRequestMessage {
585 role: Role::User,
586 content: vec!["Message 1".into()],
587 cache: false,
588 reasoning_details: None,
589 },
590 LanguageModelRequestMessage {
591 role: Role::Assistant,
592 content: vec!["Response to Message 1".into()],
593 cache: false,
594 reasoning_details: None,
595 },
596 LanguageModelRequestMessage {
597 role: Role::User,
598 content: vec!["Message 2".into()],
599 cache: false,
600 reasoning_details: None,
601 },
602 LanguageModelRequestMessage {
603 role: Role::Assistant,
604 content: vec!["Response to Message 2".into()],
605 cache: false,
606 reasoning_details: None,
607 },
608 LanguageModelRequestMessage {
609 role: Role::User,
610 content: vec!["Use the echo tool".into()],
611 cache: false,
612 reasoning_details: None,
613 },
614 LanguageModelRequestMessage {
615 role: Role::Assistant,
616 content: vec![MessageContent::ToolUse(tool_use)],
617 cache: false,
618 reasoning_details: None,
619 },
620 LanguageModelRequestMessage {
621 role: Role::User,
622 content: vec![MessageContent::ToolResult(tool_result)],
623 cache: true,
624 reasoning_details: None,
625 }
626 ]
627 );
628}
629
630#[gpui::test]
631#[cfg_attr(not(feature = "e2e"), ignore)]
632async fn test_basic_tool_calls(cx: &mut TestAppContext) {
633 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
634
635 // Test a tool call that's likely to complete *before* streaming stops.
636 let events = thread
637 .update(cx, |thread, cx| {
638 thread.add_tool(EchoTool);
639 thread.send(
640 UserMessageId::new(),
641 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
642 cx,
643 )
644 })
645 .unwrap()
646 .collect()
647 .await;
648 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
649
650 // Test a tool calls that's likely to complete *after* streaming stops.
651 let events = thread
652 .update(cx, |thread, cx| {
653 thread.remove_tool(&EchoTool::name());
654 thread.add_tool(DelayTool);
655 thread.send(
656 UserMessageId::new(),
657 [
658 "Now call the delay tool with 200ms.",
659 "When the timer goes off, then you echo the output of the tool.",
660 ],
661 cx,
662 )
663 })
664 .unwrap()
665 .collect()
666 .await;
667 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
668 thread.update(cx, |thread, _cx| {
669 assert!(
670 thread
671 .last_message()
672 .unwrap()
673 .as_agent_message()
674 .unwrap()
675 .content
676 .iter()
677 .any(|content| {
678 if let AgentMessageContent::Text(text) = content {
679 text.contains("Ding")
680 } else {
681 false
682 }
683 }),
684 "{}",
685 thread.to_markdown()
686 );
687 });
688}
689
690#[gpui::test]
691#[cfg_attr(not(feature = "e2e"), ignore)]
692async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
693 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
694
695 // Test a tool call that's likely to complete *before* streaming stops.
696 let mut events = thread
697 .update(cx, |thread, cx| {
698 thread.add_tool(WordListTool);
699 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
700 })
701 .unwrap();
702
703 let mut saw_partial_tool_use = false;
704 while let Some(event) = events.next().await {
705 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
706 thread.update(cx, |thread, _cx| {
707 // Look for a tool use in the thread's last message
708 let message = thread.last_message().unwrap();
709 let agent_message = message.as_agent_message().unwrap();
710 let last_content = agent_message.content.last().unwrap();
711 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
712 assert_eq!(last_tool_use.name.as_ref(), "word_list");
713 if tool_call.status == acp::ToolCallStatus::Pending {
714 if !last_tool_use.is_input_complete
715 && last_tool_use.input.get("g").is_none()
716 {
717 saw_partial_tool_use = true;
718 }
719 } else {
720 last_tool_use
721 .input
722 .get("a")
723 .expect("'a' has streamed because input is now complete");
724 last_tool_use
725 .input
726 .get("g")
727 .expect("'g' has streamed because input is now complete");
728 }
729 } else {
730 panic!("last content should be a tool use");
731 }
732 });
733 }
734 }
735
736 assert!(
737 saw_partial_tool_use,
738 "should see at least one partially streamed tool use in the history"
739 );
740}
741
/// Walks a permission-gated tool through the authorization flow: allow once,
/// deny, "always allow" for the tool, and a final call that must complete
/// without waiting on another authorization response.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The model requests the permission-gated tool twice in the same turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // Each of the two tool uses yields its own authorization request.
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries both results: success for the approved
    // call and an error result for the denied one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    // No authorization response is sent here; the result must still reach the
    // next request after the "always allow" choice above.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
880
881#[gpui::test]
882async fn test_tool_hallucination(cx: &mut TestAppContext) {
883 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
884 let fake_model = model.as_fake();
885
886 let mut events = thread
887 .update(cx, |thread, cx| {
888 thread.send(UserMessageId::new(), ["abc"], cx)
889 })
890 .unwrap();
891 cx.run_until_parked();
892 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
893 LanguageModelToolUse {
894 id: "tool_id_1".into(),
895 name: "nonexistent_tool".into(),
896 raw_input: "{}".into(),
897 input: json!({}),
898 is_input_complete: true,
899 thought_signature: None,
900 },
901 ));
902 fake_model.end_last_completion_stream();
903
904 let tool_call = expect_tool_call(&mut events).await;
905 assert_eq!(tool_call.title, "nonexistent_tool");
906 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
907 let update = expect_tool_call_update_fields(&mut events).await;
908 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
909}
910
911async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
912 let event = events
913 .next()
914 .await
915 .expect("no tool call authorization event received")
916 .unwrap();
917 match event {
918 ThreadEvent::ToolCall(tool_call) => tool_call,
919 event => {
920 panic!("Unexpected event {event:?}");
921 }
922 }
923}
924
925async fn expect_tool_call_update_fields(
926 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
927) -> acp::ToolCallUpdate {
928 let event = events
929 .next()
930 .await
931 .expect("no tool call authorization event received")
932 .unwrap();
933 match event {
934 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
935 event => {
936 panic!("Unexpected event {event:?}");
937 }
938 }
939}
940
941async fn next_tool_call_authorization(
942 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
943) -> ToolCallAuthorization {
944 loop {
945 let event = events
946 .next()
947 .await
948 .expect("no tool call authorization event received")
949 .unwrap();
950 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
951 let permission_kinds = tool_call_authorization
952 .options
953 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
954 .map(|option| option.kind);
955 let allow_once = tool_call_authorization
956 .options
957 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
958 .map(|option| option.kind);
959
960 assert_eq!(
961 permission_kinds,
962 Some(acp::PermissionOptionKind::AllowAlways)
963 );
964 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
965 return tool_call_authorization;
966 }
967 }
968}
969
/// A terminal command with a recognizable program offers a tool-wide option,
/// a program-scoped option, and a one-off option.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let options =
        ToolPermissionContext::new("terminal", "cargo build --release").build_permission_options();
    let PermissionOptions::Dropdown(choices) = options else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in [
        "Always for terminal",
        "Always for `cargo` commands",
        "Only this time",
    ] {
        assert!(labels.contains(&expected));
    }
}
988
/// An edit to a file offers a tool-wide option and one scoped to the file's
/// parent directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let options =
        ToolPermissionContext::new("edit_file", "src/main.rs").build_permission_options();
    let PermissionOptions::Dropdown(choices) = options else {
        panic!("Expected dropdown permission options");
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(labels.contains(&expected));
    }
}
1005
/// A fetch of a URL offers a tool-wide option and one scoped to the URL's domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let options =
        ToolPermissionContext::new("fetch", "https://docs.rs/gpui").build_permission_options();
    let PermissionOptions::Dropdown(choices) = options else {
        panic!("Expected dropdown permission options");
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(labels.contains(&expected));
    }
}
1022
/// When no pattern can be derived from the command, only the tool-wide and
/// one-off options are offered.
#[test]
fn test_permission_options_without_pattern() {
    let options = ToolPermissionContext::new("terminal", "./deploy.sh --production")
        .build_permission_options();
    let PermissionOptions::Dropdown(choices) = options else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Only this time"));
    // No program-scoped ("... commands") option should be present.
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1041
/// Allow and deny choices expose matching option ids: tool-wide, one-off, and
/// pattern-scoped.
#[test]
fn test_permission_option_ids_for_terminal() {
    let options =
        ToolPermissionContext::new("terminal", "cargo build --release").build_permission_options();
    let PermissionOptions::Dropdown(choices) = options else {
        panic!("Expected dropdown permission options");
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    // (ids, tool-wide id, one-off id, pattern id prefix, message if the pattern id is missing)
    let expectations = [
        (
            &allow_ids,
            "always_allow:terminal",
            "allow",
            "always_allow_pattern:terminal:",
            "Missing allow pattern option",
        ),
        (
            &deny_ids,
            "always_deny:terminal",
            "deny",
            "always_deny_pattern:terminal:",
            "Missing deny pattern option",
        ),
    ];
    for (ids, tool_wide, one_off, pattern_prefix, missing_pattern) in expectations {
        assert!(ids.iter().any(|id| id == tool_wide));
        assert!(ids.iter().any(|id| id == one_off));
        assert!(
            ids.iter().any(|id| id.starts_with(pattern_prefix)),
            "{}",
            missing_pattern
        );
    }
}
1078
1079#[gpui::test]
1080#[cfg_attr(not(feature = "e2e"), ignore)]
1081async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1082 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1083
1084 // Test concurrent tool calls with different delay times
1085 let events = thread
1086 .update(cx, |thread, cx| {
1087 thread.add_tool(DelayTool);
1088 thread.send(
1089 UserMessageId::new(),
1090 [
1091 "Call the delay tool twice in the same message.",
1092 "Once with 100ms. Once with 300ms.",
1093 "When both timers are complete, describe the outputs.",
1094 ],
1095 cx,
1096 )
1097 })
1098 .unwrap()
1099 .collect()
1100 .await;
1101
1102 let stop_reasons = stop_events(events);
1103 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1104
1105 thread.update(cx, |thread, _cx| {
1106 let last_message = thread.last_message().unwrap();
1107 let agent_message = last_message.as_agent_message().unwrap();
1108 let text = agent_message
1109 .content
1110 .iter()
1111 .filter_map(|content| {
1112 if let AgentMessageContent::Text(text) = content {
1113 Some(text.as_str())
1114 } else {
1115 None
1116 }
1117 })
1118 .collect::<String>();
1119
1120 assert!(text.contains("Ding"));
1121 });
1122}
1123
1124#[gpui::test]
1125async fn test_profiles(cx: &mut TestAppContext) {
1126 let ThreadTest {
1127 model, thread, fs, ..
1128 } = setup(cx, TestModel::Fake).await;
1129 let fake_model = model.as_fake();
1130
1131 thread.update(cx, |thread, _cx| {
1132 thread.add_tool(DelayTool);
1133 thread.add_tool(EchoTool);
1134 thread.add_tool(InfiniteTool);
1135 });
1136
1137 // Override profiles and wait for settings to be loaded.
1138 fs.insert_file(
1139 paths::settings_file(),
1140 json!({
1141 "agent": {
1142 "profiles": {
1143 "test-1": {
1144 "name": "Test Profile 1",
1145 "tools": {
1146 EchoTool::name(): true,
1147 DelayTool::name(): true,
1148 }
1149 },
1150 "test-2": {
1151 "name": "Test Profile 2",
1152 "tools": {
1153 InfiniteTool::name(): true,
1154 }
1155 }
1156 }
1157 }
1158 })
1159 .to_string()
1160 .into_bytes(),
1161 )
1162 .await;
1163 cx.run_until_parked();
1164
1165 // Test that test-1 profile (default) has echo and delay tools
1166 thread
1167 .update(cx, |thread, cx| {
1168 thread.set_profile(AgentProfileId("test-1".into()), cx);
1169 thread.send(UserMessageId::new(), ["test"], cx)
1170 })
1171 .unwrap();
1172 cx.run_until_parked();
1173
1174 let mut pending_completions = fake_model.pending_completions();
1175 assert_eq!(pending_completions.len(), 1);
1176 let completion = pending_completions.pop().unwrap();
1177 let tool_names: Vec<String> = completion
1178 .tools
1179 .iter()
1180 .map(|tool| tool.name.clone())
1181 .collect();
1182 assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1183 fake_model.end_last_completion_stream();
1184
1185 // Switch to test-2 profile, and verify that it has only the infinite tool.
1186 thread
1187 .update(cx, |thread, cx| {
1188 thread.set_profile(AgentProfileId("test-2".into()), cx);
1189 thread.send(UserMessageId::new(), ["test2"], cx)
1190 })
1191 .unwrap();
1192 cx.run_until_parked();
1193 let mut pending_completions = fake_model.pending_completions();
1194 assert_eq!(pending_completions.len(), 1);
1195 let completion = pending_completions.pop().unwrap();
1196 let tool_names: Vec<String> = completion
1197 .tools
1198 .iter()
1199 .map(|tool| tool.name.clone())
1200 .collect();
1201 assert_eq!(tool_names, vec![InfiniteTool::name()]);
1202}
1203
1204#[gpui::test]
1205async fn test_mcp_tools(cx: &mut TestAppContext) {
1206 let ThreadTest {
1207 model,
1208 thread,
1209 context_server_store,
1210 fs,
1211 ..
1212 } = setup(cx, TestModel::Fake).await;
1213 let fake_model = model.as_fake();
1214
1215 // Override profiles and wait for settings to be loaded.
1216 fs.insert_file(
1217 paths::settings_file(),
1218 json!({
1219 "agent": {
1220 "always_allow_tool_actions": true,
1221 "profiles": {
1222 "test": {
1223 "name": "Test Profile",
1224 "enable_all_context_servers": true,
1225 "tools": {
1226 EchoTool::name(): true,
1227 }
1228 },
1229 }
1230 }
1231 })
1232 .to_string()
1233 .into_bytes(),
1234 )
1235 .await;
1236 cx.run_until_parked();
1237 thread.update(cx, |thread, cx| {
1238 thread.set_profile(AgentProfileId("test".into()), cx)
1239 });
1240
1241 let mut mcp_tool_calls = setup_context_server(
1242 "test_server",
1243 vec![context_server::types::Tool {
1244 name: "echo".into(),
1245 description: None,
1246 input_schema: serde_json::to_value(EchoTool::input_schema(
1247 LanguageModelToolSchemaFormat::JsonSchema,
1248 ))
1249 .unwrap(),
1250 output_schema: None,
1251 annotations: None,
1252 }],
1253 &context_server_store,
1254 cx,
1255 );
1256
1257 let events = thread.update(cx, |thread, cx| {
1258 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1259 });
1260 cx.run_until_parked();
1261
1262 // Simulate the model calling the MCP tool.
1263 let completion = fake_model.pending_completions().pop().unwrap();
1264 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1265 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1266 LanguageModelToolUse {
1267 id: "tool_1".into(),
1268 name: "echo".into(),
1269 raw_input: json!({"text": "test"}).to_string(),
1270 input: json!({"text": "test"}),
1271 is_input_complete: true,
1272 thought_signature: None,
1273 },
1274 ));
1275 fake_model.end_last_completion_stream();
1276 cx.run_until_parked();
1277
1278 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1279 assert_eq!(tool_call_params.name, "echo");
1280 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1281 tool_call_response
1282 .send(context_server::types::CallToolResponse {
1283 content: vec![context_server::types::ToolResponseContent::Text {
1284 text: "test".into(),
1285 }],
1286 is_error: None,
1287 meta: None,
1288 structured_content: None,
1289 })
1290 .unwrap();
1291 cx.run_until_parked();
1292
1293 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1294 fake_model.send_last_completion_stream_text_chunk("Done!");
1295 fake_model.end_last_completion_stream();
1296 events.collect::<Vec<_>>().await;
1297
1298 // Send again after adding the echo tool, ensuring the name collision is resolved.
1299 let events = thread.update(cx, |thread, cx| {
1300 thread.add_tool(EchoTool);
1301 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1302 });
1303 cx.run_until_parked();
1304 let completion = fake_model.pending_completions().pop().unwrap();
1305 assert_eq!(
1306 tool_names_for_completion(&completion),
1307 vec!["echo", "test_server_echo"]
1308 );
1309 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1310 LanguageModelToolUse {
1311 id: "tool_2".into(),
1312 name: "test_server_echo".into(),
1313 raw_input: json!({"text": "mcp"}).to_string(),
1314 input: json!({"text": "mcp"}),
1315 is_input_complete: true,
1316 thought_signature: None,
1317 },
1318 ));
1319 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1320 LanguageModelToolUse {
1321 id: "tool_3".into(),
1322 name: "echo".into(),
1323 raw_input: json!({"text": "native"}).to_string(),
1324 input: json!({"text": "native"}),
1325 is_input_complete: true,
1326 thought_signature: None,
1327 },
1328 ));
1329 fake_model.end_last_completion_stream();
1330 cx.run_until_parked();
1331
1332 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1333 assert_eq!(tool_call_params.name, "echo");
1334 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1335 tool_call_response
1336 .send(context_server::types::CallToolResponse {
1337 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1338 is_error: None,
1339 meta: None,
1340 structured_content: None,
1341 })
1342 .unwrap();
1343 cx.run_until_parked();
1344
1345 // Ensure the tool results were inserted with the correct names.
1346 let completion = fake_model.pending_completions().pop().unwrap();
1347 assert_eq!(
1348 completion.messages.last().unwrap().content,
1349 vec![
1350 MessageContent::ToolResult(LanguageModelToolResult {
1351 tool_use_id: "tool_3".into(),
1352 tool_name: "echo".into(),
1353 is_error: false,
1354 content: "native".into(),
1355 output: Some("native".into()),
1356 },),
1357 MessageContent::ToolResult(LanguageModelToolResult {
1358 tool_use_id: "tool_2".into(),
1359 tool_name: "test_server_echo".into(),
1360 is_error: false,
1361 content: "mcp".into(),
1362 output: Some("mcp".into()),
1363 },),
1364 ]
1365 );
1366 fake_model.end_last_completion_stream();
1367 events.collect::<Vec<_>>().await;
1368}
1369
1370#[gpui::test]
1371async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1372 let ThreadTest {
1373 model,
1374 thread,
1375 context_server_store,
1376 fs,
1377 ..
1378 } = setup(cx, TestModel::Fake).await;
1379 let fake_model = model.as_fake();
1380
1381 // Set up a profile with all tools enabled
1382 fs.insert_file(
1383 paths::settings_file(),
1384 json!({
1385 "agent": {
1386 "profiles": {
1387 "test": {
1388 "name": "Test Profile",
1389 "enable_all_context_servers": true,
1390 "tools": {
1391 EchoTool::name(): true,
1392 DelayTool::name(): true,
1393 WordListTool::name(): true,
1394 ToolRequiringPermission::name(): true,
1395 InfiniteTool::name(): true,
1396 }
1397 },
1398 }
1399 }
1400 })
1401 .to_string()
1402 .into_bytes(),
1403 )
1404 .await;
1405 cx.run_until_parked();
1406
1407 thread.update(cx, |thread, cx| {
1408 thread.set_profile(AgentProfileId("test".into()), cx);
1409 thread.add_tool(EchoTool);
1410 thread.add_tool(DelayTool);
1411 thread.add_tool(WordListTool);
1412 thread.add_tool(ToolRequiringPermission);
1413 thread.add_tool(InfiniteTool);
1414 });
1415
1416 // Set up multiple context servers with some overlapping tool names
1417 let _server1_calls = setup_context_server(
1418 "xxx",
1419 vec![
1420 context_server::types::Tool {
1421 name: "echo".into(), // Conflicts with native EchoTool
1422 description: None,
1423 input_schema: serde_json::to_value(EchoTool::input_schema(
1424 LanguageModelToolSchemaFormat::JsonSchema,
1425 ))
1426 .unwrap(),
1427 output_schema: None,
1428 annotations: None,
1429 },
1430 context_server::types::Tool {
1431 name: "unique_tool_1".into(),
1432 description: None,
1433 input_schema: json!({"type": "object", "properties": {}}),
1434 output_schema: None,
1435 annotations: None,
1436 },
1437 ],
1438 &context_server_store,
1439 cx,
1440 );
1441
1442 let _server2_calls = setup_context_server(
1443 "yyy",
1444 vec![
1445 context_server::types::Tool {
1446 name: "echo".into(), // Also conflicts with native EchoTool
1447 description: None,
1448 input_schema: serde_json::to_value(EchoTool::input_schema(
1449 LanguageModelToolSchemaFormat::JsonSchema,
1450 ))
1451 .unwrap(),
1452 output_schema: None,
1453 annotations: None,
1454 },
1455 context_server::types::Tool {
1456 name: "unique_tool_2".into(),
1457 description: None,
1458 input_schema: json!({"type": "object", "properties": {}}),
1459 output_schema: None,
1460 annotations: None,
1461 },
1462 context_server::types::Tool {
1463 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1464 description: None,
1465 input_schema: json!({"type": "object", "properties": {}}),
1466 output_schema: None,
1467 annotations: None,
1468 },
1469 context_server::types::Tool {
1470 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1471 description: None,
1472 input_schema: json!({"type": "object", "properties": {}}),
1473 output_schema: None,
1474 annotations: None,
1475 },
1476 ],
1477 &context_server_store,
1478 cx,
1479 );
1480 let _server3_calls = setup_context_server(
1481 "zzz",
1482 vec![
1483 context_server::types::Tool {
1484 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1485 description: None,
1486 input_schema: json!({"type": "object", "properties": {}}),
1487 output_schema: None,
1488 annotations: None,
1489 },
1490 context_server::types::Tool {
1491 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1492 description: None,
1493 input_schema: json!({"type": "object", "properties": {}}),
1494 output_schema: None,
1495 annotations: None,
1496 },
1497 context_server::types::Tool {
1498 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1499 description: None,
1500 input_schema: json!({"type": "object", "properties": {}}),
1501 output_schema: None,
1502 annotations: None,
1503 },
1504 ],
1505 &context_server_store,
1506 cx,
1507 );
1508
1509 thread
1510 .update(cx, |thread, cx| {
1511 thread.send(UserMessageId::new(), ["Go"], cx)
1512 })
1513 .unwrap();
1514 cx.run_until_parked();
1515 let completion = fake_model.pending_completions().pop().unwrap();
1516 assert_eq!(
1517 tool_names_for_completion(&completion),
1518 vec![
1519 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1520 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1521 "delay",
1522 "echo",
1523 "infinite",
1524 "tool_requiring_permission",
1525 "unique_tool_1",
1526 "unique_tool_2",
1527 "word_list",
1528 "xxx_echo",
1529 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1530 "yyy_echo",
1531 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1532 ]
1533 );
1534}
1535
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (ignored unless the `e2e` feature is enabled) against a
    // real Sonnet 4 model. It starts a turn that calls the echo tool and then
    // the infinite tool, cancels the turn once echo has completed, and checks
    // two things:
    // - the event stream ends with `Stop(Cancelled)`;
    // - the thread can still complete a fresh turn afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls must arrive in exactly this order: each `ToolCall` event
    // consumes the next expected title (an unexpected extra call would panic
    // on `remove(0)`).
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Only a `Completed` status update for the echo call counts.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Both calls seen and echo finished: the infinite tool may still be running.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1621
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // Cancelling a turn while the terminal tool is running must:
    // - kill the terminal;
    // - stop the turn with `Cancelled`;
    // - still record a tool result that contains the terminal's output plus a
    //   "user stopped" note.
    // The thread must remain usable afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake terminal that (judging by the constructor name) never exits on its
    // own, so the tool stays running until it is cancelled.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running.
    // The cancel task is detached rather than awaited; `collect_events_until_stop`
    // below drives the executor until the Stop event arrives.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1718
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is a shared flag the tool sets when it observes cancellation.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported.
    // Poll a bounded number of times: drain whatever events are ready, then
    // sleep briefly. The inner `break` only leaves the draining loop; the outer
    // loop exits via the `tool_started` check below.
    let mut tool_started = false;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    // NOTE(review): under the gpui test executor this timer presumably runs on
    // simulated time, so the timeout branch may never fire unless the clock is
    // advanced — confirm it can actually catch a hung cancellation.
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1813
1814/// Helper to verify thread can recover after cancellation by sending a simple message.
1815async fn verify_thread_recovery(
1816 thread: &Entity<Thread>,
1817 fake_model: &FakeLanguageModel,
1818 cx: &mut TestAppContext,
1819) {
1820 let events = thread
1821 .update(cx, |thread, cx| {
1822 thread.send(
1823 UserMessageId::new(),
1824 ["Testing: reply with 'Hello' then stop."],
1825 cx,
1826 )
1827 })
1828 .unwrap();
1829 cx.run_until_parked();
1830 fake_model.send_last_completion_stream_text_chunk("Hello");
1831 fake_model
1832 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1833 fake_model.end_last_completion_stream();
1834
1835 let events = events.collect::<Vec<_>>().await;
1836 thread.update(cx, |thread, _cx| {
1837 let message = thread.last_message().unwrap();
1838 let agent_message = message.as_agent_message().unwrap();
1839 assert_eq!(
1840 agent_message.content,
1841 vec![AgentMessageContent::Text("Hello".to_string())]
1842 );
1843 });
1844 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1845}
1846
1847/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1848async fn wait_for_terminal_tool_started(
1849 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1850 cx: &mut TestAppContext,
1851) {
1852 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1853 for _ in 0..deadline {
1854 cx.run_until_parked();
1855
1856 while let Some(Some(event)) = events.next().now_or_never() {
1857 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1858 update,
1859 ))) = &event
1860 {
1861 if update.fields.content.as_ref().is_some_and(|content| {
1862 content
1863 .iter()
1864 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1865 }) {
1866 return;
1867 }
1868 }
1869 }
1870
1871 cx.background_executor
1872 .timer(Duration::from_millis(10))
1873 .await;
1874 }
1875 panic!("terminal tool did not start within the expected time");
1876}
1877
1878/// Collects events until a Stop event is received, driving the executor to completion.
1879async fn collect_events_until_stop(
1880 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1881 cx: &mut TestAppContext,
1882) -> Vec<Result<ThreadEvent>> {
1883 let mut collected = Vec::new();
1884 let deadline = cx.executor().num_cpus() * 200;
1885
1886 for _ in 0..deadline {
1887 cx.executor().advance_clock(Duration::from_millis(10));
1888 cx.run_until_parked();
1889
1890 while let Some(Some(event)) = events.next().now_or_never() {
1891 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1892 collected.push(event);
1893 if is_stop {
1894 return collected;
1895 }
1896 }
1897 }
1898 panic!(
1899 "did not receive Stop event within the expected time; collected {} events",
1900 collected.len()
1901 );
1902}
1903
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    // Truncating the thread at its only user message while the terminal tool is
    // running must:
    // - kill the terminal;
    // - leave the thread empty (its markdown is "");
    // - keep the thread usable for new messages.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    // Keep the id so the thread can be truncated back to this message later.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete
    // (the collected events themselves are not inspected here).
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1970
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Environment that hands out one terminal handle per terminal created;
    // `environment.handles()` is inspected after cancellation.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment.clone(),
            ));
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates
    // (same polling scheme as `wait_for_terminal_tool_started`, but counting to
    // two). The inner `break` only leaves the draining loop, so events after
    // the second start stay queued for `collect_events_until_stop`.
    let mut terminals_started = 0;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running
    // (detached; `collect_events_until_stop` drives the executor to completion).
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2079
/// Verifies that stopping a running terminal command via the terminal card's own stop
/// control (instead of cancelling the whole thread) lets the turn finish with `EndTurn`
/// and yields a tool result saying the user stopped the command.
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // The environment hands out this same handle, so the test can drive the "terminal"
    // directly. NOTE(review): `new_never_exits` presumably only finishes once
    // `signal_exit` is called — confirm against `FakeTerminalHandle`.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    // The thread itself is NOT cancelled here, which is what distinguishes this path
    // from thread-level cancellation.
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        // The tool use and its result live on the last (agent) message; results are
        // looked up by the tool-use id.
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2174
/// Verifies that when the model passes `timeout_ms` and the command outlives it, the
/// terminal is killed, the turn still ends with `EndTurn`, and the tool result reports a
/// timeout rather than a user stop.
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // The handle is shared with the environment so its `was_killed` state can be
    // inspected after the timeout fires.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: "terminal".into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout (200ms > the 100ms `timeout_ms` requested above)
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        // The tool use and its result live on the last (agent) message; results are
        // looked up by the tool-use id.
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        // A timeout kill must not be misreported as a user-initiated stop.
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2273
2274#[gpui::test]
2275async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2276 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2277 let fake_model = model.as_fake();
2278
2279 let events_1 = thread
2280 .update(cx, |thread, cx| {
2281 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2282 })
2283 .unwrap();
2284 cx.run_until_parked();
2285 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2286 cx.run_until_parked();
2287
2288 let events_2 = thread
2289 .update(cx, |thread, cx| {
2290 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2291 })
2292 .unwrap();
2293 cx.run_until_parked();
2294 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2295 fake_model
2296 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2297 fake_model.end_last_completion_stream();
2298
2299 let events_1 = events_1.collect::<Vec<_>>().await;
2300 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2301 let events_2 = events_2.collect::<Vec<_>>().await;
2302 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2303}
2304
2305#[gpui::test]
2306async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2307 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2308 let fake_model = model.as_fake();
2309
2310 let events_1 = thread
2311 .update(cx, |thread, cx| {
2312 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2313 })
2314 .unwrap();
2315 cx.run_until_parked();
2316 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2317 fake_model
2318 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2319 fake_model.end_last_completion_stream();
2320 let events_1 = events_1.collect::<Vec<_>>().await;
2321
2322 let events_2 = thread
2323 .update(cx, |thread, cx| {
2324 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2325 })
2326 .unwrap();
2327 cx.run_until_parked();
2328 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2329 fake_model
2330 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2331 fake_model.end_last_completion_stream();
2332 let events_2 = events_2.collect::<Vec<_>>().await;
2333
2334 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2335 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2336}
2337
2338#[gpui::test]
2339async fn test_refusal(cx: &mut TestAppContext) {
2340 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2341 let fake_model = model.as_fake();
2342
2343 let events = thread
2344 .update(cx, |thread, cx| {
2345 thread.send(UserMessageId::new(), ["Hello"], cx)
2346 })
2347 .unwrap();
2348 cx.run_until_parked();
2349 thread.read_with(cx, |thread, _| {
2350 assert_eq!(
2351 thread.to_markdown(),
2352 indoc! {"
2353 ## User
2354
2355 Hello
2356 "}
2357 );
2358 });
2359
2360 fake_model.send_last_completion_stream_text_chunk("Hey!");
2361 cx.run_until_parked();
2362 thread.read_with(cx, |thread, _| {
2363 assert_eq!(
2364 thread.to_markdown(),
2365 indoc! {"
2366 ## User
2367
2368 Hello
2369
2370 ## Assistant
2371
2372 Hey!
2373 "}
2374 );
2375 });
2376
2377 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2378 fake_model
2379 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2380 let events = events.collect::<Vec<_>>().await;
2381 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2382 thread.read_with(cx, |thread, _| {
2383 assert_eq!(thread.to_markdown(), "");
2384 });
2385}
2386
2387#[gpui::test]
2388async fn test_truncate_first_message(cx: &mut TestAppContext) {
2389 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2390 let fake_model = model.as_fake();
2391
2392 let message_id = UserMessageId::new();
2393 thread
2394 .update(cx, |thread, cx| {
2395 thread.send(message_id.clone(), ["Hello"], cx)
2396 })
2397 .unwrap();
2398 cx.run_until_parked();
2399 thread.read_with(cx, |thread, _| {
2400 assert_eq!(
2401 thread.to_markdown(),
2402 indoc! {"
2403 ## User
2404
2405 Hello
2406 "}
2407 );
2408 assert_eq!(thread.latest_token_usage(), None);
2409 });
2410
2411 fake_model.send_last_completion_stream_text_chunk("Hey!");
2412 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2413 language_model::TokenUsage {
2414 input_tokens: 32_000,
2415 output_tokens: 16_000,
2416 cache_creation_input_tokens: 0,
2417 cache_read_input_tokens: 0,
2418 },
2419 ));
2420 cx.run_until_parked();
2421 thread.read_with(cx, |thread, _| {
2422 assert_eq!(
2423 thread.to_markdown(),
2424 indoc! {"
2425 ## User
2426
2427 Hello
2428
2429 ## Assistant
2430
2431 Hey!
2432 "}
2433 );
2434 assert_eq!(
2435 thread.latest_token_usage(),
2436 Some(acp_thread::TokenUsage {
2437 used_tokens: 32_000 + 16_000,
2438 max_tokens: 1_000_000,
2439 input_tokens: 32_000,
2440 output_tokens: 16_000,
2441 })
2442 );
2443 });
2444
2445 thread
2446 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2447 .unwrap();
2448 cx.run_until_parked();
2449 thread.read_with(cx, |thread, _| {
2450 assert_eq!(thread.to_markdown(), "");
2451 assert_eq!(thread.latest_token_usage(), None);
2452 });
2453
2454 // Ensure we can still send a new message after truncation.
2455 thread
2456 .update(cx, |thread, cx| {
2457 thread.send(UserMessageId::new(), ["Hi"], cx)
2458 })
2459 .unwrap();
2460 thread.update(cx, |thread, _cx| {
2461 assert_eq!(
2462 thread.to_markdown(),
2463 indoc! {"
2464 ## User
2465
2466 Hi
2467 "}
2468 );
2469 });
2470 cx.run_until_parked();
2471 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2472 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2473 language_model::TokenUsage {
2474 input_tokens: 40_000,
2475 output_tokens: 20_000,
2476 cache_creation_input_tokens: 0,
2477 cache_read_input_tokens: 0,
2478 },
2479 ));
2480 cx.run_until_parked();
2481 thread.read_with(cx, |thread, _| {
2482 assert_eq!(
2483 thread.to_markdown(),
2484 indoc! {"
2485 ## User
2486
2487 Hi
2488
2489 ## Assistant
2490
2491 Ahoy!
2492 "}
2493 );
2494
2495 assert_eq!(
2496 thread.latest_token_usage(),
2497 Some(acp_thread::TokenUsage {
2498 used_tokens: 40_000 + 20_000,
2499 max_tokens: 1_000_000,
2500 input_tokens: 40_000,
2501 output_tokens: 20_000,
2502 })
2503 );
2504 });
2505}
2506
2507#[gpui::test]
2508async fn test_truncate_second_message(cx: &mut TestAppContext) {
2509 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2510 let fake_model = model.as_fake();
2511
2512 thread
2513 .update(cx, |thread, cx| {
2514 thread.send(UserMessageId::new(), ["Message 1"], cx)
2515 })
2516 .unwrap();
2517 cx.run_until_parked();
2518 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2519 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2520 language_model::TokenUsage {
2521 input_tokens: 32_000,
2522 output_tokens: 16_000,
2523 cache_creation_input_tokens: 0,
2524 cache_read_input_tokens: 0,
2525 },
2526 ));
2527 fake_model.end_last_completion_stream();
2528 cx.run_until_parked();
2529
2530 let assert_first_message_state = |cx: &mut TestAppContext| {
2531 thread.clone().read_with(cx, |thread, _| {
2532 assert_eq!(
2533 thread.to_markdown(),
2534 indoc! {"
2535 ## User
2536
2537 Message 1
2538
2539 ## Assistant
2540
2541 Message 1 response
2542 "}
2543 );
2544
2545 assert_eq!(
2546 thread.latest_token_usage(),
2547 Some(acp_thread::TokenUsage {
2548 used_tokens: 32_000 + 16_000,
2549 max_tokens: 1_000_000,
2550 input_tokens: 32_000,
2551 output_tokens: 16_000,
2552 })
2553 );
2554 });
2555 };
2556
2557 assert_first_message_state(cx);
2558
2559 let second_message_id = UserMessageId::new();
2560 thread
2561 .update(cx, |thread, cx| {
2562 thread.send(second_message_id.clone(), ["Message 2"], cx)
2563 })
2564 .unwrap();
2565 cx.run_until_parked();
2566
2567 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2568 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2569 language_model::TokenUsage {
2570 input_tokens: 40_000,
2571 output_tokens: 20_000,
2572 cache_creation_input_tokens: 0,
2573 cache_read_input_tokens: 0,
2574 },
2575 ));
2576 fake_model.end_last_completion_stream();
2577 cx.run_until_parked();
2578
2579 thread.read_with(cx, |thread, _| {
2580 assert_eq!(
2581 thread.to_markdown(),
2582 indoc! {"
2583 ## User
2584
2585 Message 1
2586
2587 ## Assistant
2588
2589 Message 1 response
2590
2591 ## User
2592
2593 Message 2
2594
2595 ## Assistant
2596
2597 Message 2 response
2598 "}
2599 );
2600
2601 assert_eq!(
2602 thread.latest_token_usage(),
2603 Some(acp_thread::TokenUsage {
2604 used_tokens: 40_000 + 20_000,
2605 max_tokens: 1_000_000,
2606 input_tokens: 40_000,
2607 output_tokens: 20_000,
2608 })
2609 );
2610 });
2611
2612 thread
2613 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2614 .unwrap();
2615 cx.run_until_parked();
2616
2617 assert_first_message_state(cx);
2618}
2619
2620#[gpui::test]
2621async fn test_title_generation(cx: &mut TestAppContext) {
2622 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2623 let fake_model = model.as_fake();
2624
2625 let summary_model = Arc::new(FakeLanguageModel::default());
2626 thread.update(cx, |thread, cx| {
2627 thread.set_summarization_model(Some(summary_model.clone()), cx)
2628 });
2629
2630 let send = thread
2631 .update(cx, |thread, cx| {
2632 thread.send(UserMessageId::new(), ["Hello"], cx)
2633 })
2634 .unwrap();
2635 cx.run_until_parked();
2636
2637 fake_model.send_last_completion_stream_text_chunk("Hey!");
2638 fake_model.end_last_completion_stream();
2639 cx.run_until_parked();
2640 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2641
2642 // Ensure the summary model has been invoked to generate a title.
2643 summary_model.send_last_completion_stream_text_chunk("Hello ");
2644 summary_model.send_last_completion_stream_text_chunk("world\nG");
2645 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2646 summary_model.end_last_completion_stream();
2647 send.collect::<Vec<_>>().await;
2648 cx.run_until_parked();
2649 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2650
2651 // Send another message, ensuring no title is generated this time.
2652 let send = thread
2653 .update(cx, |thread, cx| {
2654 thread.send(UserMessageId::new(), ["Hello again"], cx)
2655 })
2656 .unwrap();
2657 cx.run_until_parked();
2658 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2659 fake_model.end_last_completion_stream();
2660 cx.run_until_parked();
2661 assert_eq!(summary_model.pending_completions(), Vec::new());
2662 send.collect::<Vec<_>>().await;
2663 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2664}
2665
/// Building a completion request while one tool call is still pending must omit that tool
/// use. A tool call that already produced a result is included, together with its result.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Unlike other tests, `always_allow_tools` is not called here.
    // NOTE(review): presumably `ToolRequiringPermission` therefore stays pending awaiting
    // authorization, while `EchoTool` completes immediately — confirm.
    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // The first request message is skipped (presumably the system prompt — not shown here).
    // The rest must contain only the echo tool use and its result; the permission tool use
    // must be absent.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2745
2746#[gpui::test]
2747async fn test_agent_connection(cx: &mut TestAppContext) {
2748 cx.update(settings::init);
2749 let templates = Templates::new();
2750
2751 // Initialize language model system with test provider
2752 cx.update(|cx| {
2753 gpui_tokio::init(cx);
2754
2755 let http_client = FakeHttpClient::with_404_response();
2756 let clock = Arc::new(clock::FakeSystemClock::new());
2757 let client = Client::new(clock, http_client, cx);
2758 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2759 language_model::init(client.clone(), cx);
2760 language_models::init(user_store, client.clone(), cx);
2761 LanguageModelRegistry::test(cx);
2762 });
2763 cx.executor().forbid_parking();
2764
2765 // Create a project for new_thread
2766 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2767 fake_fs.insert_tree(path!("/test"), json!({})).await;
2768 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2769 let cwd = Path::new("/test");
2770 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2771
2772 // Create agent and connection
2773 let agent = NativeAgent::new(
2774 project.clone(),
2775 thread_store,
2776 templates.clone(),
2777 None,
2778 fake_fs.clone(),
2779 &mut cx.to_async(),
2780 )
2781 .await
2782 .unwrap();
2783 let connection = NativeAgentConnection(agent.clone());
2784
2785 // Create a thread using new_thread
2786 let connection_rc = Rc::new(connection.clone());
2787 let acp_thread = cx
2788 .update(|cx| connection_rc.new_thread(project, cwd, cx))
2789 .await
2790 .expect("new_thread should succeed");
2791
2792 // Get the session_id from the AcpThread
2793 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2794
2795 // Test model_selector returns Some
2796 let selector_opt = connection.model_selector(&session_id);
2797 assert!(
2798 selector_opt.is_some(),
2799 "agent should always support ModelSelector"
2800 );
2801 let selector = selector_opt.unwrap();
2802
2803 // Test list_models
2804 let listed_models = cx
2805 .update(|cx| selector.list_models(cx))
2806 .await
2807 .expect("list_models should succeed");
2808 let AgentModelList::Grouped(listed_models) = listed_models else {
2809 panic!("Unexpected model list type");
2810 };
2811 assert!(!listed_models.is_empty(), "should have at least one model");
2812 assert_eq!(
2813 listed_models[&AgentModelGroupName("Fake".into())][0]
2814 .id
2815 .0
2816 .as_ref(),
2817 "fake/fake"
2818 );
2819
2820 // Test selected_model returns the default
2821 let model = cx
2822 .update(|cx| selector.selected_model(cx))
2823 .await
2824 .expect("selected_model should succeed");
2825 let model = cx
2826 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2827 .unwrap();
2828 let model = model.as_fake();
2829 assert_eq!(model.id().0, "fake", "should return default model");
2830
2831 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2832 cx.run_until_parked();
2833 model.send_last_completion_stream_text_chunk("def");
2834 cx.run_until_parked();
2835 acp_thread.read_with(cx, |thread, cx| {
2836 assert_eq!(
2837 thread.to_markdown(cx),
2838 indoc! {"
2839 ## User
2840
2841 abc
2842
2843 ## Assistant
2844
2845 def
2846
2847 "}
2848 )
2849 });
2850
2851 // Test cancel
2852 cx.update(|cx| connection.cancel(&session_id, cx));
2853 request.await.expect("prompt should fail gracefully");
2854
2855 // Ensure that dropping the ACP thread causes the native thread to be
2856 // dropped as well.
2857 cx.update(|_| drop(acp_thread));
2858 let result = cx
2859 .update(|cx| {
2860 connection.prompt(
2861 Some(acp_thread::UserMessageId::new()),
2862 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2863 cx,
2864 )
2865 })
2866 .await;
2867 assert_eq!(
2868 result.as_ref().unwrap_err().to_string(),
2869 "Session not found",
2870 "unexpected result: {:?}",
2871 result
2872 );
2873}
2874
/// Streams a `thinking` tool call in two steps (partial input, then complete input) and
/// checks the exact sequence of ACP tool-call events up to completion.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Expected order: the initial ToolCall (carrying the partial input), an update with
    // the final input, InProgress, the tool's content, then Completed with its output.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall::new("1", "Thinking")
            .kind(acp::ToolKind::Think)
            .raw_input(json!({}))
            .meta(acp::Meta::from_iter([(
                "tool_name".into(),
                "thinking".into()
            )]))
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .title("Thinking")
                .kind(acp::ToolKind::Think)
                .raw_input(json!({ "content": "Thinking hard!"}))
        )
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
        )
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
        )
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .status(acp::ToolCallStatus::Completed)
                .raw_output("Finished thinking.")
        )
    );
}
2965
2966#[gpui::test]
2967async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2968 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2969 let fake_model = model.as_fake();
2970
2971 let mut events = thread
2972 .update(cx, |thread, cx| {
2973 thread.send(UserMessageId::new(), ["Hello!"], cx)
2974 })
2975 .unwrap();
2976 cx.run_until_parked();
2977
2978 fake_model.send_last_completion_stream_text_chunk("Hey!");
2979 fake_model.end_last_completion_stream();
2980
2981 let mut retry_events = Vec::new();
2982 while let Some(Ok(event)) = events.next().await {
2983 match event {
2984 ThreadEvent::Retry(retry_status) => {
2985 retry_events.push(retry_status);
2986 }
2987 ThreadEvent::Stop(..) => break,
2988 _ => {}
2989 }
2990 }
2991
2992 assert_eq!(retry_events.len(), 0);
2993 thread.read_with(cx, |thread, _cx| {
2994 assert_eq!(
2995 thread.to_markdown(),
2996 indoc! {"
2997 ## User
2998
2999 Hello!
3000
3001 ## Assistant
3002
3003 Hey!
3004 "}
3005 )
3006 });
3007}
3008
3009#[gpui::test]
3010async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3011 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3012 let fake_model = model.as_fake();
3013
3014 let mut events = thread
3015 .update(cx, |thread, cx| {
3016 thread.send(UserMessageId::new(), ["Hello!"], cx)
3017 })
3018 .unwrap();
3019 cx.run_until_parked();
3020
3021 fake_model.send_last_completion_stream_text_chunk("Hey,");
3022 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3023 provider: LanguageModelProviderName::new("Anthropic"),
3024 retry_after: Some(Duration::from_secs(3)),
3025 });
3026 fake_model.end_last_completion_stream();
3027
3028 cx.executor().advance_clock(Duration::from_secs(3));
3029 cx.run_until_parked();
3030
3031 fake_model.send_last_completion_stream_text_chunk("there!");
3032 fake_model.end_last_completion_stream();
3033 cx.run_until_parked();
3034
3035 let mut retry_events = Vec::new();
3036 while let Some(Ok(event)) = events.next().await {
3037 match event {
3038 ThreadEvent::Retry(retry_status) => {
3039 retry_events.push(retry_status);
3040 }
3041 ThreadEvent::Stop(..) => break,
3042 _ => {}
3043 }
3044 }
3045
3046 assert_eq!(retry_events.len(), 1);
3047 assert!(matches!(
3048 retry_events[0],
3049 acp_thread::RetryStatus { attempt: 1, .. }
3050 ));
3051 thread.read_with(cx, |thread, _cx| {
3052 assert_eq!(
3053 thread.to_markdown(),
3054 indoc! {"
3055 ## User
3056
3057 Hello!
3058
3059 ## Assistant
3060
3061 Hey,
3062
3063 [resume]
3064
3065 ## Assistant
3066
3067 there!
3068 "}
3069 )
3070 });
3071}
3072
/// When a completion fails with a retryable error after the model already emitted a tool
/// use, that tool call is still finished before retrying. The retried request therefore
/// contains both the tool use and its result.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // First attempt: a complete tool use, then a retryable error in the same stream.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Move past the requested 3s back-off so the retry request is issued, then inspect it.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // The first message is skipped (presumably the system prompt — not shown here). The
    // retry must carry the tool use plus its result, with `cache` set on the final message.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // The retried completion finishes with plain text; that becomes the last agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3152
3153#[gpui::test]
3154async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3155 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3156 let fake_model = model.as_fake();
3157
3158 let mut events = thread
3159 .update(cx, |thread, cx| {
3160 thread.send(UserMessageId::new(), ["Hello!"], cx)
3161 })
3162 .unwrap();
3163 cx.run_until_parked();
3164
3165 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3166 fake_model.send_last_completion_stream_error(
3167 LanguageModelCompletionError::ServerOverloaded {
3168 provider: LanguageModelProviderName::new("Anthropic"),
3169 retry_after: Some(Duration::from_secs(3)),
3170 },
3171 );
3172 fake_model.end_last_completion_stream();
3173 cx.executor().advance_clock(Duration::from_secs(3));
3174 cx.run_until_parked();
3175 }
3176
3177 let mut errors = Vec::new();
3178 let mut retry_events = Vec::new();
3179 while let Some(event) = events.next().await {
3180 match event {
3181 Ok(ThreadEvent::Retry(retry_status)) => {
3182 retry_events.push(retry_status);
3183 }
3184 Ok(ThreadEvent::Stop(..)) => break,
3185 Err(error) => errors.push(error),
3186 _ => {}
3187 }
3188 }
3189
3190 assert_eq!(
3191 retry_events.len(),
3192 crate::thread::MAX_RETRY_ATTEMPTS as usize
3193 );
3194 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3195 assert_eq!(retry_events[i].attempt, i + 1);
3196 }
3197 assert_eq!(errors.len(), 1);
3198 let error = errors[0]
3199 .downcast_ref::<LanguageModelCompletionError>()
3200 .unwrap();
3201 assert!(matches!(
3202 error,
3203 LanguageModelCompletionError::ServerOverloaded { .. }
3204 ));
3205}
3206
3207/// Filters out the stop events for asserting against in tests
3208fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3209 result_events
3210 .into_iter()
3211 .filter_map(|event| match event.unwrap() {
3212 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3213 _ => None,
3214 })
3215 .collect()
3216}
3217
/// Handles returned by `setup` for a single agent-thread test.
struct ThreadTest {
    /// The model the thread was created with. For `TestModel::Fake` this is a
    /// `FakeLanguageModel`.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context given to the thread. `setup` creates it empty.
    project_context: Entity<ProjectContext>,
    /// The project's context server store, e.g. for `setup_context_server`.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing the test project and the user settings file.
    fs: Arc<FakeFs>,
}
3225
/// Selects which language model `setup` attaches to the thread under test.
enum TestModel {
    /// A real hosted model resolved through `LanguageModelRegistry`. `setup` initializes a
    /// production client for it and requires its provider to authenticate successfully.
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose completion streams the test drives by hand.
    Fake,
}
3230
3231impl TestModel {
3232 fn id(&self) -> LanguageModelId {
3233 match self {
3234 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3235 TestModel::Fake => unreachable!(),
3236 }
3237 }
3238}
3239
/// Builds a `Thread` backed by a fake filesystem and a test project, using the requested model.
///
/// In order, this:
/// - writes a user settings file whose default agent profile (`test-profile`) enables the
///   test tools (`EchoTool`, `DelayTool`, `WordListTool`, `ToolRequiringPermission`,
///   `InfiniteTool`, `CancellationAwareTool`, `ThinkingTool`) plus `terminal`;
/// - initializes settings and, for `TestModel::Sonnet4` only, the tokio runtime, a
///   `ReqwestClient`, the production `Client`, and the language model providers;
/// - starts `watch_settings` so the settings file is applied to the global `SettingsStore`;
/// - creates a project rooted at `/test` and a `Thread` using the resolved model.
///
/// For `TestModel::Sonnet4`, the model is found by id in `LanguageModelRegistry` and its
/// provider is authenticated. Any failure on that path panics via `unwrap`.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Let the test executor park while waiting on work instead of treating that as an error.
    // NOTE(review): presumably needed for the settings watcher and real-network paths; confirm.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // User settings selecting a profile that enables every test tool plus `terminal`.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            CancellationAwareTool::name(): true,
                            ThinkingTool::name(): true,
                            "terminal": true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            // Real models need an HTTP client, an authenticated `Client`, and registered
            // providers so the model can be found in `LanguageModelRegistry` below.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model. The fake model is ready immediately. A real model is returned only
    // after its provider has finished authenticating.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3343
/// Installs `env_logger` once, before any test runs, but only if the developer asked for log
/// output by setting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
3351
3352fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3353 let fs = fs.clone();
3354 cx.spawn({
3355 async move |cx| {
3356 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3357 cx.background_executor(),
3358 fs,
3359 paths::settings_file().clone(),
3360 );
3361 let _watcher_task = watcher_task;
3362
3363 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3364 cx.update(|cx| {
3365 SettingsStore::update_global(cx, |settings, cx| {
3366 settings.set_user_settings(&new_settings_content, cx)
3367 })
3368 })
3369 .ok();
3370 }
3371 }
3372 })
3373 .detach();
3374}
3375
3376fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3377 completion
3378 .tools
3379 .iter()
3380 .map(|tool| tool.name.clone())
3381 .collect()
3382}
3383
/// Registers and starts a fake stdio MCP context server named `name` that advertises `tools`.
///
/// The server is first added to the global `ProjectSettings` (enabled, pointing at a dummy
/// `somebinary` command). It is then started on `context_server_store` with a fake transport
/// that answers `Initialize` and `ListTools` itself.
///
/// Returns a receiver that yields every `CallTool` request the server gets, paired with a
/// oneshot sender. The test must answer through that sender: the fake transport awaits the
/// response and panics (via `unwrap`) if the sender is dropped.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report the latest protocol version and advertise tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Every listing returns the same fixed set of tools.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Tool calls are forwarded to the test, which supplies the response via the oneshot.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3463
/// Checks `Thread::tokens_before_message` across three user turns.
///
/// Each user message should report the input token count from the usage of the request that
/// preceded it. That is `None` for the first message, then 100 and 250 from the first and
/// second responses. Earlier messages keep their values as later turns are added.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3570
/// Checks that `tokens_before_message` forgets messages removed by `Thread::truncate`.
///
/// The test runs two turns, with input usage of 100 and then 250. It then truncates at the
/// second message. Lookups for that message should return `None`, and the first message
/// should still report `None`, as before.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // NOTE(review): only two turns are sent below; this comment looks stale.
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3641
/// Exercises how `TerminalTool` applies the `terminal` entry of `tool_permissions`.
///
/// There are four independent scenarios. Each installs its own rules via
/// `AgentSettings::override_global` and then runs the tool directly:
/// 1. An `always_deny` regex blocks a matching command with an error that mentions "blocked".
/// 2. An `always_allow` regex runs a matching command without confirmation, even when
///    `default_mode` is `Deny`.
/// 3. `always_allow_tool_actions = true` overrides an `always_confirm` match.
/// 4. `always_allow_tool_actions = true` overrides `default_mode: Deny`.
///
/// NOTE(review): scenario 1 does not set `always_allow_tool_actions`, so it relies on the value
/// left by the test `SettingsStore`. It also uses a terminal handle that never exits, so a
/// regression in the deny rule would presumably make the test hang rather than fail. Confirm.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        // The environment is an `Rc`, so this `Arc` wraps a non-Send value. The lint is
        // silenced deliberately because the tool only lives on this test's thread.
        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        assert!(
            result.unwrap_err().to_string().contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The first update should already carry terminal content, i.e. the command ran
        // without a confirmation step in between.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, confirm patterns are overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }

    // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
    {
        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
        let environment = Rc::new(FakeThreadEnvironment {
            handle: handle.clone(),
        });

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                "terminal".into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, even default_mode: Deny is overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }
}
3849
/// Checks that `Thread::add_default_tools` registers the `subagent` tool on a top-level
/// thread when the `subagents` feature flag is enabled.
#[gpui::test]
async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    // `add_default_tools` needs a thread environment; a never-exiting fake terminal suffices
    // because no tool is actually run here.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment { handle });

    let thread = cx.new(|cx| {
        let mut thread = Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model),
            cx,
        );
        thread.add_default_tools(environment, cx);
        thread
    });

    thread.read_with(cx, |thread, _| {
        assert!(
            thread.has_registered_tool("subagent"),
            "subagent tool should be present when feature flag is enabled"
        );
    });
}
3890
3891#[gpui::test]
3892async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3893 init_test(cx);
3894
3895 cx.update(|cx| {
3896 cx.update_flags(true, vec!["subagents".to_string()]);
3897 });
3898
3899 let fs = FakeFs::new(cx.executor());
3900 fs.insert_tree(path!("/test"), json!({})).await;
3901 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3902 let project_context = cx.new(|_cx| ProjectContext::default());
3903 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3904 let context_server_registry =
3905 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3906 let model = Arc::new(FakeLanguageModel::default());
3907
3908 let subagent_context = SubagentContext {
3909 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3910 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3911 depth: 1,
3912 summary_prompt: "Summarize".to_string(),
3913 context_low_prompt: "Context low".to_string(),
3914 };
3915
3916 let subagent = cx.new(|cx| {
3917 Thread::new_subagent(
3918 project.clone(),
3919 project_context,
3920 context_server_registry,
3921 Templates::new(),
3922 model.clone(),
3923 subagent_context,
3924 std::collections::BTreeMap::new(),
3925 cx,
3926 )
3927 });
3928
3929 subagent.read_with(cx, |thread, _| {
3930 assert!(thread.is_subagent());
3931 assert_eq!(thread.depth(), 1);
3932 assert!(thread.model().is_some());
3933 });
3934}
3935
/// Checks that a subagent created at `MAX_SUBAGENT_DEPTH` does not receive the `subagent` tool
/// from `add_default_tools`, even with the `subagents` flag enabled. This means it cannot
/// spawn deeper subagents through that tool.
#[gpui::test]
async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    // A subagent already at the nesting limit.
    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: MAX_SUBAGENT_DEPTH,
        summary_prompt: "Summarize".to_string(),
        context_low_prompt: "Context low".to_string(),
    };

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment { handle });

    let deep_subagent = cx.new(|cx| {
        let mut thread = Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        );
        thread.add_default_tools(environment, cx);
        thread
    });

    deep_subagent.read_with(cx, |thread, _| {
        assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
        assert!(
            !thread.has_registered_tool("subagent"),
            "subagent tool should not be present at max depth"
        );
    });
}
3987
/// Checks that a prompt submitted to a subagent with `submit_user_message` is sent to the
/// model as the only user message of the first completion request.
///
/// `setup` is used only to obtain a project and a fake model. The subagent shares that fake
/// model with the parent thread created by `setup`.
#[gpui::test]
async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Summarize your work".to_string(),
        context_low_prompt: "Context low, wrap up".to_string(),
    };

    let project = thread.read_with(cx, |t, _| t.project.clone());
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    let task_prompt = "Find all TODO comments in the codebase";
    subagent
        .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
        .unwrap();
    cx.run_until_parked();

    let pending = fake_model.pending_completions();
    assert_eq!(pending.len(), 1, "should have one pending completion");

    let messages = &pending[0].messages;
    let user_messages: Vec<_> = messages
        .iter()
        .filter(|m| m.role == language_model::Role::User)
        .collect();
    assert_eq!(user_messages.len(), 1, "should have one user message");

    // Only the first content chunk of the user message is inspected.
    let content = &user_messages[0].content[0];
    assert!(
        content.to_str().unwrap().contains("TODO"),
        "task prompt should be in user message"
    );
}
4046
/// Checks that `request_final_summary` sends the subagent's `summary_prompt` to the model.
///
/// After one completed turn, requesting the final summary should leave a pending completion.
/// The last user message of that completion should contain the summary prompt text
/// ("summarize").
#[gpui::test]
async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Please summarize what you found".to_string(),
        context_low_prompt: "Context low, wrap up".to_string(),
    };

    let project = thread.read_with(cx, |t, _| t.project.clone());
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    subagent
        .update(cx, |thread, cx| {
            thread.submit_user_message("Do some work", cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Finish the work turn normally before asking for the summary.
    fake_model.send_last_completion_stream_text_chunk("I did the work");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    subagent
        .update(cx, |thread, cx| thread.request_final_summary(cx))
        .unwrap();
    cx.run_until_parked();

    let pending = fake_model.pending_completions();
    assert!(
        !pending.is_empty(),
        "should have pending completion for summary"
    );

    let messages = &pending.last().unwrap().messages;
    let user_messages: Vec<_> = messages
        .iter()
        .filter(|m| m.role == language_model::Role::User)
        .collect();

    let last_user = user_messages.last().unwrap();
    assert!(
        last_user.content[0].to_str().unwrap().contains("summarize"),
        "summary prompt should be sent"
    );
}
4119
4120#[gpui::test]
4121async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4122 init_test(cx);
4123
4124 cx.update(|cx| {
4125 cx.update_flags(true, vec!["subagents".to_string()]);
4126 });
4127
4128 let fs = FakeFs::new(cx.executor());
4129 fs.insert_tree(path!("/test"), json!({})).await;
4130 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4131 let project_context = cx.new(|_cx| ProjectContext::default());
4132 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4133 let context_server_registry =
4134 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4135 let model = Arc::new(FakeLanguageModel::default());
4136
4137 let subagent_context = SubagentContext {
4138 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4139 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4140 depth: 1,
4141 summary_prompt: "Summarize".to_string(),
4142 context_low_prompt: "Context low".to_string(),
4143 };
4144
4145 let subagent = cx.new(|cx| {
4146 let mut thread = Thread::new_subagent(
4147 project.clone(),
4148 project_context,
4149 context_server_registry,
4150 Templates::new(),
4151 model.clone(),
4152 subagent_context,
4153 std::collections::BTreeMap::new(),
4154 cx,
4155 );
4156 thread.add_tool(EchoTool);
4157 thread.add_tool(DelayTool);
4158 thread.add_tool(WordListTool);
4159 thread
4160 });
4161
4162 subagent.read_with(cx, |thread, _| {
4163 assert!(thread.has_registered_tool("echo"));
4164 assert!(thread.has_registered_tool("delay"));
4165 assert!(thread.has_registered_tool("word_list"));
4166 });
4167
4168 let allowed: collections::HashSet<gpui::SharedString> =
4169 vec!["echo".into()].into_iter().collect();
4170
4171 subagent.update(cx, |thread, _cx| {
4172 thread.restrict_tools(&allowed);
4173 });
4174
4175 subagent.read_with(cx, |thread, _| {
4176 assert!(
4177 thread.has_registered_tool("echo"),
4178 "echo should still be available"
4179 );
4180 assert!(
4181 !thread.has_registered_tool("delay"),
4182 "delay should be removed"
4183 );
4184 assert!(
4185 !thread.has_registered_tool("word_list"),
4186 "word_list should be removed"
4187 );
4188 });
4189}
4190
/// Checks that cancelling a parent thread also ends the turn of a subagent registered on it
/// with `register_running_subagent`.
///
/// NOTE(review): the subagent's state is asserted right after `cancel`, with no
/// `run_until_parked` in between and with the returned cancel task detached. The test
/// therefore relies on the parent cancelling registered subagents synchronously.
#[gpui::test]
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent_context = SubagentContext {
        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
        depth: 1,
        summary_prompt: "Summarize".to_string(),
        context_low_prompt: "Context low".to_string(),
    };

    let subagent = cx.new(|cx| {
        Thread::new_subagent(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            model.clone(),
            subagent_context,
            std::collections::BTreeMap::new(),
            cx,
        )
    });

    // The parent only holds a weak handle to the subagent.
    parent.update(cx, |thread, _cx| {
        thread.register_running_subagent(subagent.downgrade());
    });

    // Start a subagent turn and leave its completion pending so it is still running.
    subagent
        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
        .unwrap();
    cx.run_until_parked();

    subagent.read_with(cx, |thread, _| {
        assert!(!thread.is_turn_complete(), "subagent should be running");
    });

    parent.update(cx, |thread, cx| {
        thread.cancel(cx).detach();
    });

    subagent.read_with(cx, |thread, _| {
        assert!(
            thread.is_turn_complete(),
            "subagent should be cancelled when parent cancels"
        );
    });
}
4264
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // Verifies that the subagent tool reacts to a user cancellation signalled
    // through its event stream by failing with a "cancelled by user" error.
    // The tool presumably observes this via `event_stream.cancelled_by_user()`.
    // This test checks the returned error only; it does not directly inspect
    // the spawned subagent threads.
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // The parent exposes no tools to its subagents in this test.
    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
        std::collections::BTreeMap::new();

    // NOTE(review): the lint is allowed because clippy flags `Arc` around a
    // value that is not `Send + Sync`; the tool only needs shared ownership here.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(
        parent.downgrade(),
        project.clone(),
        project_context,
        context_server_registry,
        Templates::new(),
        0, // NOTE(review): presumably the parent's nesting depth — confirm against `SubagentTool::new`
        parent_tools,
    ));

    // `cancellation_tx` is the test-side handle used below to signal cancellation.
    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                subagents: vec![crate::SubagentConfig {
                    label: "Long running task".to_string(),
                    task_prompt: "Do a very long task that takes forever".to_string(),
                    summary_prompt: "Summarize".to_string(),
                    context_low_prompt: "Context low".to_string(),
                    timeout_ms: None,
                    allowed_tools: None,
                }],
            },
            event_stream.clone(),
            cx,
        )
    });

    // Let the subagent start; nothing answers its model request, so it stays busy.
    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error. The timer
    // bounds the wait so that a tool which ignores cancellation makes this
    // test fail instead of hang.
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4353
4354#[gpui::test]
4355async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4356 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4357 let fake_model = model.as_fake();
4358
4359 cx.update(|cx| {
4360 cx.update_flags(true, vec!["subagents".to_string()]);
4361 });
4362
4363 let subagent_context = SubagentContext {
4364 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4365 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4366 depth: 1,
4367 summary_prompt: "Summarize".to_string(),
4368 context_low_prompt: "Context low".to_string(),
4369 };
4370
4371 let project = thread.read_with(cx, |t, _| t.project.clone());
4372 let project_context = cx.new(|_cx| ProjectContext::default());
4373 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4374 let context_server_registry =
4375 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4376
4377 let subagent = cx.new(|cx| {
4378 Thread::new_subagent(
4379 project.clone(),
4380 project_context,
4381 context_server_registry,
4382 Templates::new(),
4383 model.clone(),
4384 subagent_context,
4385 std::collections::BTreeMap::new(),
4386 cx,
4387 )
4388 });
4389
4390 subagent
4391 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4392 .unwrap();
4393 cx.run_until_parked();
4394
4395 subagent.read_with(cx, |thread, _| {
4396 assert!(!thread.is_turn_complete(), "turn should be in progress");
4397 });
4398
4399 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4400 provider: LanguageModelProviderName::from("Fake".to_string()),
4401 });
4402 fake_model.end_last_completion_stream();
4403 cx.run_until_parked();
4404
4405 subagent.read_with(cx, |thread, _| {
4406 assert!(
4407 thread.is_turn_complete(),
4408 "turn should be complete after non-retryable error"
4409 );
4410 });
4411}
4412
4413#[gpui::test]
4414async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4415 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4416 let fake_model = model.as_fake();
4417
4418 cx.update(|cx| {
4419 cx.update_flags(true, vec!["subagents".to_string()]);
4420 });
4421
4422 let subagent_context = SubagentContext {
4423 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4424 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4425 depth: 1,
4426 summary_prompt: "Summarize your work".to_string(),
4427 context_low_prompt: "Context low, stop and summarize".to_string(),
4428 };
4429
4430 let project = thread.read_with(cx, |t, _| t.project.clone());
4431 let project_context = cx.new(|_cx| ProjectContext::default());
4432 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4433 let context_server_registry =
4434 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4435
4436 let subagent = cx.new(|cx| {
4437 Thread::new_subagent(
4438 project.clone(),
4439 project_context.clone(),
4440 context_server_registry.clone(),
4441 Templates::new(),
4442 model.clone(),
4443 subagent_context.clone(),
4444 std::collections::BTreeMap::new(),
4445 cx,
4446 )
4447 });
4448
4449 subagent.update(cx, |thread, _| {
4450 thread.add_tool(EchoTool);
4451 });
4452
4453 subagent
4454 .update(cx, |thread, cx| {
4455 thread.submit_user_message("Do some work", cx)
4456 })
4457 .unwrap();
4458 cx.run_until_parked();
4459
4460 fake_model.send_last_completion_stream_text_chunk("Working on it...");
4461 fake_model
4462 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4463 fake_model.end_last_completion_stream();
4464 cx.run_until_parked();
4465
4466 let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4467 assert!(
4468 interrupt_result.is_ok(),
4469 "interrupt_for_summary should succeed"
4470 );
4471
4472 cx.run_until_parked();
4473
4474 let pending = fake_model.pending_completions();
4475 assert!(
4476 !pending.is_empty(),
4477 "should have pending completion for interrupted summary"
4478 );
4479
4480 let messages = &pending.last().unwrap().messages;
4481 let user_messages: Vec<_> = messages
4482 .iter()
4483 .filter(|m| m.role == language_model::Role::User)
4484 .collect();
4485
4486 let last_user = user_messages.last().unwrap();
4487 let content_str = last_user.content[0].to_str().unwrap();
4488 assert!(
4489 content_str.contains("Context low") || content_str.contains("stop and summarize"),
4490 "context_low_prompt should be sent when interrupting: got {:?}",
4491 content_str
4492 );
4493}
4494
4495#[gpui::test]
4496async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4497 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4498 let fake_model = model.as_fake();
4499
4500 cx.update(|cx| {
4501 cx.update_flags(true, vec!["subagents".to_string()]);
4502 });
4503
4504 let subagent_context = SubagentContext {
4505 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4506 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4507 depth: 1,
4508 summary_prompt: "Summarize".to_string(),
4509 context_low_prompt: "Context low".to_string(),
4510 };
4511
4512 let project = thread.read_with(cx, |t, _| t.project.clone());
4513 let project_context = cx.new(|_cx| ProjectContext::default());
4514 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4515 let context_server_registry =
4516 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4517
4518 let subagent = cx.new(|cx| {
4519 Thread::new_subagent(
4520 project.clone(),
4521 project_context,
4522 context_server_registry,
4523 Templates::new(),
4524 model.clone(),
4525 subagent_context,
4526 std::collections::BTreeMap::new(),
4527 cx,
4528 )
4529 });
4530
4531 subagent
4532 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4533 .unwrap();
4534 cx.run_until_parked();
4535
4536 let max_tokens = model.max_token_count();
4537 let high_usage = language_model::TokenUsage {
4538 input_tokens: (max_tokens as f64 * 0.80) as u64,
4539 output_tokens: 0,
4540 cache_creation_input_tokens: 0,
4541 cache_read_input_tokens: 0,
4542 };
4543
4544 fake_model
4545 .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4546 fake_model.send_last_completion_stream_text_chunk("Working...");
4547 fake_model
4548 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4549 fake_model.end_last_completion_stream();
4550 cx.run_until_parked();
4551
4552 let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4553 assert!(usage.is_some(), "should have token usage after completion");
4554
4555 let usage = usage.unwrap();
4556 let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4557 assert!(
4558 remaining_ratio <= 0.25,
4559 "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4560 remaining_ratio * 100.0
4561 );
4562}
4563
4564#[gpui::test]
4565async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4566 init_test(cx);
4567
4568 cx.update(|cx| {
4569 cx.update_flags(true, vec!["subagents".to_string()]);
4570 });
4571
4572 let fs = FakeFs::new(cx.executor());
4573 fs.insert_tree(path!("/test"), json!({})).await;
4574 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4575 let project_context = cx.new(|_cx| ProjectContext::default());
4576 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4577 let context_server_registry =
4578 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4579 let model = Arc::new(FakeLanguageModel::default());
4580
4581 let parent = cx.new(|cx| {
4582 let mut thread = Thread::new(
4583 project.clone(),
4584 project_context.clone(),
4585 context_server_registry.clone(),
4586 Templates::new(),
4587 Some(model.clone()),
4588 cx,
4589 );
4590 thread.add_tool(EchoTool);
4591 thread
4592 });
4593
4594 let mut parent_tools: std::collections::BTreeMap<
4595 gpui::SharedString,
4596 Arc<dyn crate::AnyAgentTool>,
4597 > = std::collections::BTreeMap::new();
4598 parent_tools.insert("echo".into(), EchoTool.erase());
4599
4600 #[allow(clippy::arc_with_non_send_sync)]
4601 let tool = Arc::new(SubagentTool::new(
4602 parent.downgrade(),
4603 project,
4604 project_context,
4605 context_server_registry,
4606 Templates::new(),
4607 0,
4608 parent_tools,
4609 ));
4610
4611 let subagent_configs = vec![crate::SubagentConfig {
4612 label: "Test".to_string(),
4613 task_prompt: "Do something".to_string(),
4614 summary_prompt: "Summarize".to_string(),
4615 context_low_prompt: "Context low".to_string(),
4616 timeout_ms: None,
4617 allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4618 }];
4619 let result = tool.validate_subagents(&subagent_configs);
4620 assert!(result.is_err(), "should reject unknown tool");
4621 let err_msg = result.unwrap_err().to_string();
4622 assert!(
4623 err_msg.contains("nonexistent_tool"),
4624 "error should mention the invalid tool name: {}",
4625 err_msg
4626 );
4627 assert!(
4628 err_msg.contains("do not exist"),
4629 "error should explain the tool does not exist: {}",
4630 err_msg
4631 );
4632}
4633
4634#[gpui::test]
4635async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4636 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4637 let fake_model = model.as_fake();
4638
4639 cx.update(|cx| {
4640 cx.update_flags(true, vec!["subagents".to_string()]);
4641 });
4642
4643 let subagent_context = SubagentContext {
4644 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4645 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4646 depth: 1,
4647 summary_prompt: "Summarize".to_string(),
4648 context_low_prompt: "Context low".to_string(),
4649 };
4650
4651 let project = thread.read_with(cx, |t, _| t.project.clone());
4652 let project_context = cx.new(|_cx| ProjectContext::default());
4653 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4654 let context_server_registry =
4655 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4656
4657 let subagent = cx.new(|cx| {
4658 Thread::new_subagent(
4659 project.clone(),
4660 project_context,
4661 context_server_registry,
4662 Templates::new(),
4663 model.clone(),
4664 subagent_context,
4665 std::collections::BTreeMap::new(),
4666 cx,
4667 )
4668 });
4669
4670 subagent
4671 .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4672 .unwrap();
4673 cx.run_until_parked();
4674
4675 fake_model
4676 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4677 fake_model.end_last_completion_stream();
4678 cx.run_until_parked();
4679
4680 subagent.read_with(cx, |thread, _| {
4681 assert!(
4682 thread.is_turn_complete(),
4683 "turn should complete even with empty response"
4684 );
4685 });
4686}
4687
4688#[gpui::test]
4689async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4690 init_test(cx);
4691
4692 cx.update(|cx| {
4693 cx.update_flags(true, vec!["subagents".to_string()]);
4694 });
4695
4696 let fs = FakeFs::new(cx.executor());
4697 fs.insert_tree(path!("/test"), json!({})).await;
4698 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4699 let project_context = cx.new(|_cx| ProjectContext::default());
4700 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4701 let context_server_registry =
4702 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4703 let model = Arc::new(FakeLanguageModel::default());
4704
4705 let depth_1_context = SubagentContext {
4706 parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4707 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4708 depth: 1,
4709 summary_prompt: "Summarize".to_string(),
4710 context_low_prompt: "Context low".to_string(),
4711 };
4712
4713 let depth_1_subagent = cx.new(|cx| {
4714 Thread::new_subagent(
4715 project.clone(),
4716 project_context.clone(),
4717 context_server_registry.clone(),
4718 Templates::new(),
4719 model.clone(),
4720 depth_1_context,
4721 std::collections::BTreeMap::new(),
4722 cx,
4723 )
4724 });
4725
4726 depth_1_subagent.read_with(cx, |thread, _| {
4727 assert_eq!(thread.depth(), 1);
4728 assert!(thread.is_subagent());
4729 });
4730
4731 let depth_2_context = SubagentContext {
4732 parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4733 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4734 depth: 2,
4735 summary_prompt: "Summarize depth 2".to_string(),
4736 context_low_prompt: "Context low depth 2".to_string(),
4737 };
4738
4739 let depth_2_subagent = cx.new(|cx| {
4740 Thread::new_subagent(
4741 project.clone(),
4742 project_context.clone(),
4743 context_server_registry.clone(),
4744 Templates::new(),
4745 model.clone(),
4746 depth_2_context,
4747 std::collections::BTreeMap::new(),
4748 cx,
4749 )
4750 });
4751
4752 depth_2_subagent.read_with(cx, |thread, _| {
4753 assert_eq!(thread.depth(), 2);
4754 assert!(thread.is_subagent());
4755 });
4756
4757 depth_2_subagent
4758 .update(cx, |thread, cx| {
4759 thread.submit_user_message("Nested task", cx)
4760 })
4761 .unwrap();
4762 cx.run_until_parked();
4763
4764 let pending = model.as_fake().pending_completions();
4765 assert!(
4766 !pending.is_empty(),
4767 "depth-2 subagent should be able to submit messages"
4768 );
4769}
4770
4771#[gpui::test]
4772async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4773 init_test(cx);
4774 always_allow_tools(cx);
4775
4776 cx.update(|cx| {
4777 cx.update_flags(true, vec!["subagents".to_string()]);
4778 });
4779
4780 let fs = FakeFs::new(cx.executor());
4781 fs.insert_tree(path!("/test"), json!({})).await;
4782 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4783 let project_context = cx.new(|_cx| ProjectContext::default());
4784 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4785 let context_server_registry =
4786 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4787 let model = Arc::new(FakeLanguageModel::default());
4788 let fake_model = model.as_fake();
4789
4790 let subagent_context = SubagentContext {
4791 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4792 tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4793 depth: 1,
4794 summary_prompt: "Summarize what you did".to_string(),
4795 context_low_prompt: "Context low".to_string(),
4796 };
4797
4798 let subagent = cx.new(|cx| {
4799 let mut thread = Thread::new_subagent(
4800 project.clone(),
4801 project_context,
4802 context_server_registry,
4803 Templates::new(),
4804 model.clone(),
4805 subagent_context,
4806 std::collections::BTreeMap::new(),
4807 cx,
4808 );
4809 thread.add_tool(EchoTool);
4810 thread
4811 });
4812
4813 subagent.read_with(cx, |thread, _| {
4814 assert!(
4815 thread.has_registered_tool("echo"),
4816 "subagent should have echo tool"
4817 );
4818 });
4819
4820 subagent
4821 .update(cx, |thread, cx| {
4822 thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4823 })
4824 .unwrap();
4825 cx.run_until_parked();
4826
4827 let tool_use = LanguageModelToolUse {
4828 id: "tool_call_1".into(),
4829 name: EchoTool::name().into(),
4830 raw_input: json!({"text": "hello world"}).to_string(),
4831 input: json!({"text": "hello world"}),
4832 is_input_complete: true,
4833 thought_signature: None,
4834 };
4835 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4836 fake_model.end_last_completion_stream();
4837 cx.run_until_parked();
4838
4839 let pending = fake_model.pending_completions();
4840 assert!(
4841 !pending.is_empty(),
4842 "should have pending completion after tool use"
4843 );
4844
4845 let last_completion = pending.last().unwrap();
4846 let has_tool_result = last_completion.messages.iter().any(|m| {
4847 m.content
4848 .iter()
4849 .any(|c| matches!(c, MessageContent::ToolResult(_)))
4850 });
4851 assert!(
4852 has_tool_result,
4853 "tool result should be in the messages sent back to the model"
4854 );
4855}
4856
// Once `MAX_PARALLEL_SUBAGENTS` subagents are registered as running on a
// parent, running the subagent tool again must fail with a
// "Maximum parallel subagents" error.
#[gpui::test]
async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Fill every parallel slot. The parent is only given weak handles, so the
    // strong handles are kept in `subagents`. Presumably this keeps the
    // registrations live for `running_subagent_count` — verify.
    let mut subagents = Vec::new();
    for i in 0..MAX_PARALLEL_SUBAGENTS {
        let subagent_context = SubagentContext {
            parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
            tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
            depth: 1,
            summary_prompt: "Summarize".to_string(),
            context_low_prompt: "Context low".to_string(),
        };

        let subagent = cx.new(|cx| {
            Thread::new_subagent(
                project.clone(),
                project_context.clone(),
                context_server_registry.clone(),
                Templates::new(),
                model.clone(),
                subagent_context,
                std::collections::BTreeMap::new(),
                cx,
            )
        });

        parent.update(cx, |thread, _cx| {
            thread.register_running_subagent(subagent.downgrade());
        });
        subagents.push(subagent);
    }

    parent.read_with(cx, |thread, _| {
        assert_eq!(
            thread.running_subagent_count(),
            MAX_PARALLEL_SUBAGENTS,
            "should have MAX_PARALLEL_SUBAGENTS registered"
        );
    });

    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
        std::collections::BTreeMap::new();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(
        parent.downgrade(),
        project.clone(),
        project_context,
        context_server_registry,
        Templates::new(),
        0,
        parent_tools,
    ));

    let (event_stream, _rx) = crate::ToolCallEventStream::test();

    // Asking for one more subagent while every slot is taken must be rejected.
    let result = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                subagents: vec![crate::SubagentConfig {
                    label: "Test".to_string(),
                    task_prompt: "Do something".to_string(),
                    summary_prompt: "Summarize".to_string(),
                    context_low_prompt: "Context low".to_string(),
                    timeout_ms: None,
                    allowed_tools: None,
                }],
            },
            event_stream,
            cx,
        )
    });

    let err = result.await.unwrap_err();
    assert!(
        err.to_string().contains("Maximum parallel subagents"),
        "should reject when max parallel subagents reached: {}",
        err
    );

    // The explicit drop documents that the subagent entities were meant to
    // stay alive until after the rejection was checked.
    drop(subagents);
}
4964
4965#[gpui::test]
4966async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4967 init_test(cx);
4968 always_allow_tools(cx);
4969
4970 cx.update(|cx| {
4971 cx.update_flags(true, vec!["subagents".to_string()]);
4972 });
4973
4974 let fs = FakeFs::new(cx.executor());
4975 fs.insert_tree(path!("/test"), json!({})).await;
4976 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4977 let project_context = cx.new(|_cx| ProjectContext::default());
4978 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4979 let context_server_registry =
4980 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4981 let model = Arc::new(FakeLanguageModel::default());
4982 let fake_model = model.as_fake();
4983
4984 let parent = cx.new(|cx| {
4985 let mut thread = Thread::new(
4986 project.clone(),
4987 project_context.clone(),
4988 context_server_registry.clone(),
4989 Templates::new(),
4990 Some(model.clone()),
4991 cx,
4992 );
4993 thread.add_tool(EchoTool);
4994 thread
4995 });
4996
4997 let mut parent_tools: std::collections::BTreeMap<
4998 gpui::SharedString,
4999 Arc<dyn crate::AnyAgentTool>,
5000 > = std::collections::BTreeMap::new();
5001 parent_tools.insert("echo".into(), EchoTool.erase());
5002
5003 #[allow(clippy::arc_with_non_send_sync)]
5004 let tool = Arc::new(SubagentTool::new(
5005 parent.downgrade(),
5006 project.clone(),
5007 project_context,
5008 context_server_registry,
5009 Templates::new(),
5010 0,
5011 parent_tools,
5012 ));
5013
5014 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5015
5016 let task = cx.update(|cx| {
5017 tool.run(
5018 SubagentToolInput {
5019 subagents: vec![crate::SubagentConfig {
5020 label: "Research task".to_string(),
5021 task_prompt: "Find all TODOs in the codebase".to_string(),
5022 summary_prompt: "Summarize what you found".to_string(),
5023 context_low_prompt: "Context low, wrap up".to_string(),
5024 timeout_ms: None,
5025 allowed_tools: None,
5026 }],
5027 },
5028 event_stream,
5029 cx,
5030 )
5031 });
5032
5033 cx.run_until_parked();
5034
5035 let pending = fake_model.pending_completions();
5036 assert!(
5037 !pending.is_empty(),
5038 "subagent should have started and sent a completion request"
5039 );
5040
5041 let first_completion = &pending[0];
5042 let has_task_prompt = first_completion.messages.iter().any(|m| {
5043 m.role == language_model::Role::User
5044 && m.content
5045 .iter()
5046 .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
5047 });
5048 assert!(has_task_prompt, "task prompt should be sent to subagent");
5049
5050 fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
5051 fake_model
5052 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5053 fake_model.end_last_completion_stream();
5054 cx.run_until_parked();
5055
5056 let pending = fake_model.pending_completions();
5057 assert!(
5058 !pending.is_empty(),
5059 "should have pending completion for summary request"
5060 );
5061
5062 let last_completion = pending.last().unwrap();
5063 let has_summary_prompt = last_completion.messages.iter().any(|m| {
5064 m.role == language_model::Role::User
5065 && m.content.iter().any(|c| {
5066 c.to_str()
5067 .map(|s| s.contains("Summarize") || s.contains("summarize"))
5068 .unwrap_or(false)
5069 })
5070 });
5071 assert!(
5072 has_summary_prompt,
5073 "summary prompt should be sent after task completion"
5074 );
5075
5076 fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5077 fake_model
5078 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5079 fake_model.end_last_completion_stream();
5080 cx.run_until_parked();
5081
5082 let result = task.await;
5083 assert!(result.is_ok(), "subagent tool should complete successfully");
5084
5085 let summary = result.unwrap();
5086 assert!(
5087 summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5088 "summary should contain subagent's response: {}",
5089 summary
5090 );
5091}
5092
5093#[gpui::test]
5094async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5095 init_test(cx);
5096
5097 let fs = FakeFs::new(cx.executor());
5098 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5099 .await;
5100 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5101
5102 cx.update(|cx| {
5103 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5104 settings.tool_permissions.tools.insert(
5105 "edit_file".into(),
5106 agent_settings::ToolRules {
5107 default_mode: settings::ToolPermissionMode::Allow,
5108 always_allow: vec![],
5109 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5110 always_confirm: vec![],
5111 invalid_patterns: vec![],
5112 },
5113 );
5114 agent_settings::AgentSettings::override_global(settings, cx);
5115 });
5116
5117 let context_server_registry =
5118 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5119 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5120 let templates = crate::Templates::new();
5121 let thread = cx.new(|cx| {
5122 crate::Thread::new(
5123 project.clone(),
5124 cx.new(|_cx| prompt_store::ProjectContext::default()),
5125 context_server_registry,
5126 templates.clone(),
5127 None,
5128 cx,
5129 )
5130 });
5131
5132 #[allow(clippy::arc_with_non_send_sync)]
5133 let tool = Arc::new(crate::EditFileTool::new(
5134 project.clone(),
5135 thread.downgrade(),
5136 language_registry,
5137 templates,
5138 ));
5139 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5140
5141 let task = cx.update(|cx| {
5142 tool.run(
5143 crate::EditFileToolInput {
5144 display_description: "Edit sensitive file".to_string(),
5145 path: "root/sensitive_config.txt".into(),
5146 mode: crate::EditFileMode::Edit,
5147 },
5148 event_stream,
5149 cx,
5150 )
5151 });
5152
5153 let result = task.await;
5154 assert!(result.is_err(), "expected edit to be blocked");
5155 assert!(
5156 result.unwrap_err().to_string().contains("blocked"),
5157 "error should mention the edit was blocked"
5158 );
5159}
5160
5161#[gpui::test]
5162async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5163 init_test(cx);
5164
5165 let fs = FakeFs::new(cx.executor());
5166 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5167 .await;
5168 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5169
5170 cx.update(|cx| {
5171 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5172 settings.tool_permissions.tools.insert(
5173 "delete_path".into(),
5174 agent_settings::ToolRules {
5175 default_mode: settings::ToolPermissionMode::Allow,
5176 always_allow: vec![],
5177 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5178 always_confirm: vec![],
5179 invalid_patterns: vec![],
5180 },
5181 );
5182 agent_settings::AgentSettings::override_global(settings, cx);
5183 });
5184
5185 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5186
5187 #[allow(clippy::arc_with_non_send_sync)]
5188 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5189 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5190
5191 let task = cx.update(|cx| {
5192 tool.run(
5193 crate::DeletePathToolInput {
5194 path: "root/important_data.txt".to_string(),
5195 },
5196 event_stream,
5197 cx,
5198 )
5199 });
5200
5201 let result = task.await;
5202 assert!(result.is_err(), "expected deletion to be blocked");
5203 assert!(
5204 result.unwrap_err().to_string().contains("blocked"),
5205 "error should mention the deletion was blocked"
5206 );
5207}
5208
5209#[gpui::test]
5210async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5211 init_test(cx);
5212
5213 let fs = FakeFs::new(cx.executor());
5214 fs.insert_tree(
5215 "/root",
5216 json!({
5217 "safe.txt": "content",
5218 "protected": {}
5219 }),
5220 )
5221 .await;
5222 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5223
5224 cx.update(|cx| {
5225 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5226 settings.tool_permissions.tools.insert(
5227 "move_path".into(),
5228 agent_settings::ToolRules {
5229 default_mode: settings::ToolPermissionMode::Allow,
5230 always_allow: vec![],
5231 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5232 always_confirm: vec![],
5233 invalid_patterns: vec![],
5234 },
5235 );
5236 agent_settings::AgentSettings::override_global(settings, cx);
5237 });
5238
5239 #[allow(clippy::arc_with_non_send_sync)]
5240 let tool = Arc::new(crate::MovePathTool::new(project));
5241 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5242
5243 let task = cx.update(|cx| {
5244 tool.run(
5245 crate::MovePathToolInput {
5246 source_path: "root/safe.txt".to_string(),
5247 destination_path: "root/protected/safe.txt".to_string(),
5248 },
5249 event_stream,
5250 cx,
5251 )
5252 });
5253
5254 let result = task.await;
5255 assert!(
5256 result.is_err(),
5257 "expected move to be blocked due to destination path"
5258 );
5259 assert!(
5260 result.unwrap_err().to_string().contains("blocked"),
5261 "error should mention the move was blocked"
5262 );
5263}
5264
5265#[gpui::test]
5266async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5267 init_test(cx);
5268
5269 let fs = FakeFs::new(cx.executor());
5270 fs.insert_tree(
5271 "/root",
5272 json!({
5273 "secret.txt": "secret content",
5274 "public": {}
5275 }),
5276 )
5277 .await;
5278 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5279
5280 cx.update(|cx| {
5281 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5282 settings.tool_permissions.tools.insert(
5283 "move_path".into(),
5284 agent_settings::ToolRules {
5285 default_mode: settings::ToolPermissionMode::Allow,
5286 always_allow: vec![],
5287 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5288 always_confirm: vec![],
5289 invalid_patterns: vec![],
5290 },
5291 );
5292 agent_settings::AgentSettings::override_global(settings, cx);
5293 });
5294
5295 #[allow(clippy::arc_with_non_send_sync)]
5296 let tool = Arc::new(crate::MovePathTool::new(project));
5297 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5298
5299 let task = cx.update(|cx| {
5300 tool.run(
5301 crate::MovePathToolInput {
5302 source_path: "root/secret.txt".to_string(),
5303 destination_path: "root/public/not_secret.txt".to_string(),
5304 },
5305 event_stream,
5306 cx,
5307 )
5308 });
5309
5310 let result = task.await;
5311 assert!(
5312 result.is_err(),
5313 "expected move to be blocked due to source path"
5314 );
5315 assert!(
5316 result.unwrap_err().to_string().contains("blocked"),
5317 "error should mention the move was blocked"
5318 );
5319}
5320
5321#[gpui::test]
5322async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5323 init_test(cx);
5324
5325 let fs = FakeFs::new(cx.executor());
5326 fs.insert_tree(
5327 "/root",
5328 json!({
5329 "confidential.txt": "confidential data",
5330 "dest": {}
5331 }),
5332 )
5333 .await;
5334 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5335
5336 cx.update(|cx| {
5337 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5338 settings.tool_permissions.tools.insert(
5339 "copy_path".into(),
5340 agent_settings::ToolRules {
5341 default_mode: settings::ToolPermissionMode::Allow,
5342 always_allow: vec![],
5343 always_deny: vec![
5344 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5345 ],
5346 always_confirm: vec![],
5347 invalid_patterns: vec![],
5348 },
5349 );
5350 agent_settings::AgentSettings::override_global(settings, cx);
5351 });
5352
5353 #[allow(clippy::arc_with_non_send_sync)]
5354 let tool = Arc::new(crate::CopyPathTool::new(project));
5355 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5356
5357 let task = cx.update(|cx| {
5358 tool.run(
5359 crate::CopyPathToolInput {
5360 source_path: "root/confidential.txt".to_string(),
5361 destination_path: "root/dest/copy.txt".to_string(),
5362 },
5363 event_stream,
5364 cx,
5365 )
5366 });
5367
5368 let result = task.await;
5369 assert!(result.is_err(), "expected copy to be blocked");
5370 assert!(
5371 result.unwrap_err().to_string().contains("blocked"),
5372 "error should mention the copy was blocked"
5373 );
5374}
5375
5376#[gpui::test]
5377async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5378 init_test(cx);
5379
5380 let fs = FakeFs::new(cx.executor());
5381 fs.insert_tree(
5382 "/root",
5383 json!({
5384 "normal.txt": "normal content",
5385 "readonly": {
5386 "config.txt": "readonly content"
5387 }
5388 }),
5389 )
5390 .await;
5391 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5392
5393 cx.update(|cx| {
5394 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5395 settings.tool_permissions.tools.insert(
5396 "save_file".into(),
5397 agent_settings::ToolRules {
5398 default_mode: settings::ToolPermissionMode::Allow,
5399 always_allow: vec![],
5400 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5401 always_confirm: vec![],
5402 invalid_patterns: vec![],
5403 },
5404 );
5405 agent_settings::AgentSettings::override_global(settings, cx);
5406 });
5407
5408 #[allow(clippy::arc_with_non_send_sync)]
5409 let tool = Arc::new(crate::SaveFileTool::new(project));
5410 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5411
5412 let task = cx.update(|cx| {
5413 tool.run(
5414 crate::SaveFileToolInput {
5415 paths: vec![
5416 std::path::PathBuf::from("root/normal.txt"),
5417 std::path::PathBuf::from("root/readonly/config.txt"),
5418 ],
5419 },
5420 event_stream,
5421 cx,
5422 )
5423 });
5424
5425 let result = task.await;
5426 assert!(
5427 result.is_err(),
5428 "expected save to be blocked due to denied path"
5429 );
5430 assert!(
5431 result.unwrap_err().to_string().contains("blocked"),
5432 "error should mention the save was blocked"
5433 );
5434}
5435
5436#[gpui::test]
5437async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5438 init_test(cx);
5439
5440 let fs = FakeFs::new(cx.executor());
5441 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5442 .await;
5443 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5444
5445 cx.update(|cx| {
5446 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5447 settings.always_allow_tool_actions = false;
5448 settings.tool_permissions.tools.insert(
5449 "save_file".into(),
5450 agent_settings::ToolRules {
5451 default_mode: settings::ToolPermissionMode::Allow,
5452 always_allow: vec![],
5453 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5454 always_confirm: vec![],
5455 invalid_patterns: vec![],
5456 },
5457 );
5458 agent_settings::AgentSettings::override_global(settings, cx);
5459 });
5460
5461 #[allow(clippy::arc_with_non_send_sync)]
5462 let tool = Arc::new(crate::SaveFileTool::new(project));
5463 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5464
5465 let task = cx.update(|cx| {
5466 tool.run(
5467 crate::SaveFileToolInput {
5468 paths: vec![std::path::PathBuf::from("root/config.secret")],
5469 },
5470 event_stream,
5471 cx,
5472 )
5473 });
5474
5475 let result = task.await;
5476 assert!(result.is_err(), "expected save to be blocked");
5477 assert!(
5478 result.unwrap_err().to_string().contains("blocked"),
5479 "error should mention the save was blocked"
5480 );
5481}
5482
5483#[gpui::test]
5484async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5485 init_test(cx);
5486
5487 cx.update(|cx| {
5488 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5489 settings.tool_permissions.tools.insert(
5490 "web_search".into(),
5491 agent_settings::ToolRules {
5492 default_mode: settings::ToolPermissionMode::Allow,
5493 always_allow: vec![],
5494 always_deny: vec![
5495 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5496 ],
5497 always_confirm: vec![],
5498 invalid_patterns: vec![],
5499 },
5500 );
5501 agent_settings::AgentSettings::override_global(settings, cx);
5502 });
5503
5504 #[allow(clippy::arc_with_non_send_sync)]
5505 let tool = Arc::new(crate::WebSearchTool);
5506 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5507
5508 let input: crate::WebSearchToolInput =
5509 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5510
5511 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5512
5513 let result = task.await;
5514 assert!(result.is_err(), "expected search to be blocked");
5515 assert!(
5516 result.unwrap_err().to_string().contains("blocked"),
5517 "error should mention the search was blocked"
5518 );
5519}
5520
5521#[gpui::test]
5522async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5523 init_test(cx);
5524
5525 let fs = FakeFs::new(cx.executor());
5526 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5527 .await;
5528 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5529
5530 cx.update(|cx| {
5531 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5532 settings.always_allow_tool_actions = false;
5533 settings.tool_permissions.tools.insert(
5534 "edit_file".into(),
5535 agent_settings::ToolRules {
5536 default_mode: settings::ToolPermissionMode::Confirm,
5537 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5538 always_deny: vec![],
5539 always_confirm: vec![],
5540 invalid_patterns: vec![],
5541 },
5542 );
5543 agent_settings::AgentSettings::override_global(settings, cx);
5544 });
5545
5546 let context_server_registry =
5547 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5548 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5549 let templates = crate::Templates::new();
5550 let thread = cx.new(|cx| {
5551 crate::Thread::new(
5552 project.clone(),
5553 cx.new(|_cx| prompt_store::ProjectContext::default()),
5554 context_server_registry,
5555 templates.clone(),
5556 None,
5557 cx,
5558 )
5559 });
5560
5561 #[allow(clippy::arc_with_non_send_sync)]
5562 let tool = Arc::new(crate::EditFileTool::new(
5563 project,
5564 thread.downgrade(),
5565 language_registry,
5566 templates,
5567 ));
5568 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5569
5570 let _task = cx.update(|cx| {
5571 tool.run(
5572 crate::EditFileToolInput {
5573 display_description: "Edit README".to_string(),
5574 path: "root/README.md".into(),
5575 mode: crate::EditFileMode::Edit,
5576 },
5577 event_stream,
5578 cx,
5579 )
5580 });
5581
5582 cx.run_until_parked();
5583
5584 let event = rx.try_next();
5585 assert!(
5586 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5587 "expected no authorization request for allowed .md file"
5588 );
5589}
5590
5591#[gpui::test]
5592async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5593 init_test(cx);
5594
5595 cx.update(|cx| {
5596 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5597 settings.tool_permissions.tools.insert(
5598 "fetch".into(),
5599 agent_settings::ToolRules {
5600 default_mode: settings::ToolPermissionMode::Allow,
5601 always_allow: vec![],
5602 always_deny: vec![
5603 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5604 ],
5605 always_confirm: vec![],
5606 invalid_patterns: vec![],
5607 },
5608 );
5609 agent_settings::AgentSettings::override_global(settings, cx);
5610 });
5611
5612 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5613
5614 #[allow(clippy::arc_with_non_send_sync)]
5615 let tool = Arc::new(crate::FetchTool::new(http_client));
5616 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5617
5618 let input: crate::FetchToolInput =
5619 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5620
5621 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5622
5623 let result = task.await;
5624 assert!(result.is_err(), "expected fetch to be blocked");
5625 assert!(
5626 result.unwrap_err().to_string().contains("blocked"),
5627 "error should mention the fetch was blocked"
5628 );
5629}
5630
5631#[gpui::test]
5632async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5633 init_test(cx);
5634
5635 cx.update(|cx| {
5636 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5637 settings.always_allow_tool_actions = false;
5638 settings.tool_permissions.tools.insert(
5639 "fetch".into(),
5640 agent_settings::ToolRules {
5641 default_mode: settings::ToolPermissionMode::Confirm,
5642 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5643 always_deny: vec![],
5644 always_confirm: vec![],
5645 invalid_patterns: vec![],
5646 },
5647 );
5648 agent_settings::AgentSettings::override_global(settings, cx);
5649 });
5650
5651 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5652
5653 #[allow(clippy::arc_with_non_send_sync)]
5654 let tool = Arc::new(crate::FetchTool::new(http_client));
5655 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5656
5657 let input: crate::FetchToolInput =
5658 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5659
5660 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5661
5662 cx.run_until_parked();
5663
5664 let event = rx.try_next();
5665 assert!(
5666 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5667 "expected no authorization request for allowed docs.rs URL"
5668 );
5669}
5670
5671#[gpui::test]
5672async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5673 init_test(cx);
5674 always_allow_tools(cx);
5675
5676 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5677 let fake_model = model.as_fake();
5678
5679 // Add a tool so we can simulate tool calls
5680 thread.update(cx, |thread, _cx| {
5681 thread.add_tool(EchoTool);
5682 });
5683
5684 // Start a turn by sending a message
5685 let mut events = thread
5686 .update(cx, |thread, cx| {
5687 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5688 })
5689 .unwrap();
5690 cx.run_until_parked();
5691
5692 // Simulate the model making a tool call
5693 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5694 LanguageModelToolUse {
5695 id: "tool_1".into(),
5696 name: "echo".into(),
5697 raw_input: r#"{"text": "hello"}"#.into(),
5698 input: json!({"text": "hello"}),
5699 is_input_complete: true,
5700 thought_signature: None,
5701 },
5702 ));
5703 fake_model
5704 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5705
5706 // Queue a message before ending the stream
5707 thread.update(cx, |thread, _cx| {
5708 thread.queue_message(
5709 vec![acp::ContentBlock::Text(acp::TextContent::new(
5710 "This is my queued message".to_string(),
5711 ))],
5712 vec![],
5713 );
5714 });
5715
5716 // Now end the stream - tool will run, and the boundary check should see the queue
5717 fake_model.end_last_completion_stream();
5718
5719 // Collect all events until the turn stops
5720 let all_events = collect_events_until_stop(&mut events, cx).await;
5721
5722 // Verify we received the tool call event
5723 let tool_call_ids: Vec<_> = all_events
5724 .iter()
5725 .filter_map(|e| match e {
5726 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5727 _ => None,
5728 })
5729 .collect();
5730 assert_eq!(
5731 tool_call_ids,
5732 vec!["tool_1"],
5733 "Should have received a tool call event for our echo tool"
5734 );
5735
5736 // The turn should have stopped with EndTurn
5737 let stop_reasons = stop_events(all_events);
5738 assert_eq!(
5739 stop_reasons,
5740 vec![acp::StopReason::EndTurn],
5741 "Turn should have ended after tool completion due to queued message"
5742 );
5743
5744 // Verify the queued message is still there
5745 thread.update(cx, |thread, _cx| {
5746 let queued = thread.queued_messages();
5747 assert_eq!(queued.len(), 1, "Should still have one queued message");
5748 assert!(matches!(
5749 &queued[0].content[0],
5750 acp::ContentBlock::Text(t) if t.text == "This is my queued message"
5751 ));
5752 });
5753
5754 // Thread should be idle now
5755 thread.update(cx, |thread, _cx| {
5756 assert!(
5757 thread.is_turn_complete(),
5758 "Thread should not be running after turn ends"
5759 );
5760 });
5761}