1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Test double for a terminal; implements `crate::TerminalHandle` with canned data.
struct FakeTerminalHandle {
    /// Set to `true` when `kill` is called; read through `was_killed`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; written by `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// One-shot sender taken and fired by `signal_exit`; `None` once it has been used.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task handed out by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Fixed output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Fixed id returned by `id`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
/// Test double for a subagent; implements `SubagentHandle`.
struct FakeSubagentHandle {
    /// Session id returned by `id`.
    session_id: acp::SessionId,
    /// Shared task whose output `wait_for_summary` forwards as the summary.
    wait_for_summary_task: Shared<Task<String>>,
}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
/// Thread environment test double that hands out preconfigured handles.
///
/// Both fields default to `None`; use `with_terminal` / `with_subagent` to set them.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Handle returned by `create_terminal`, which panics when this is `None`.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Handle returned by `create_subagent`, which panics when this is `None`.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    /// Every handle created by `create_terminal`, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
288#[gpui::test]
289async fn test_echo(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let events = thread
294 .update(cx, |thread, cx| {
295 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
296 })
297 .unwrap();
298 cx.run_until_parked();
299 fake_model.send_last_completion_stream_text_chunk("Hello");
300 fake_model
301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
302 fake_model.end_last_completion_stream();
303
304 let events = events.collect().await;
305 thread.update(cx, |thread, _cx| {
306 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
307 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
308 });
309 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
310}
311
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    // Runs the terminal tool with a 5 ms timeout against a terminal that never
    // exits on its own, and checks that the handle gets killed and that the
    // tool result still carries the terminal's output.
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference to the fake handle so it can be inspected after the run.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first update from the tool must include terminal content.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task without blocking until it resolves; give up once the
    // 500 ms wall-clock deadline has passed.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            // "partial output" is the canned output of the never-exiting fake handle.
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        // Let queued work run, then wait on a 1 ms executor timer before polling again.
        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
377
378#[gpui::test]
379#[ignore]
380async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
381 init_test(cx);
382 always_allow_tools(cx);
383
384 let fs = FakeFs::new(cx.executor());
385 let project = Project::test(fs, [], cx).await;
386
387 let environment = Rc::new(cx.update(|cx| {
388 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
389 }));
390 let handle = environment.terminal_handle.clone().unwrap();
391
392 #[allow(clippy::arc_with_non_send_sync)]
393 let tool = Arc::new(crate::TerminalTool::new(project, environment));
394 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
395
396 let _task = cx.update(|cx| {
397 tool.run(
398 crate::TerminalToolInput {
399 command: "sleep 1000".to_string(),
400 cd: ".".to_string(),
401 timeout_ms: None,
402 },
403 event_stream,
404 cx,
405 )
406 });
407
408 let update = rx.expect_update_fields().await;
409 assert!(
410 update.content.iter().any(|blocks| {
411 blocks
412 .iter()
413 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
414 }),
415 "expected tool call update to include terminal content"
416 );
417
418 cx.background_executor
419 .timer(Duration::from_millis(25))
420 .await;
421
422 assert!(
423 !handle.was_killed(),
424 "did not expect terminal handle to be killed without a timeout"
425 );
426}
427
428#[gpui::test]
429async fn test_thinking(cx: &mut TestAppContext) {
430 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
431 let fake_model = model.as_fake();
432
433 let events = thread
434 .update(cx, |thread, cx| {
435 thread.send(
436 UserMessageId::new(),
437 [indoc! {"
438 Testing:
439
440 Generate a thinking step where you just think the word 'Think',
441 and have your final answer be 'Hello'
442 "}],
443 cx,
444 )
445 })
446 .unwrap();
447 cx.run_until_parked();
448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
449 text: "Think".to_string(),
450 signature: None,
451 });
452 fake_model.send_last_completion_stream_text_chunk("Hello");
453 fake_model
454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
455 fake_model.end_last_completion_stream();
456
457 let events = events.collect().await;
458 thread.update(cx, |thread, _cx| {
459 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
460 assert_eq!(
461 thread.last_message().unwrap().to_markdown(),
462 indoc! {"
463 <think>Think</think>
464 Hello
465 "}
466 )
467 });
468 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
469}
470
471#[gpui::test]
472async fn test_system_prompt(cx: &mut TestAppContext) {
473 let ThreadTest {
474 model,
475 thread,
476 project_context,
477 ..
478 } = setup(cx, TestModel::Fake).await;
479 let fake_model = model.as_fake();
480
481 project_context.update(cx, |project_context, _cx| {
482 project_context.shell = "test-shell".into()
483 });
484 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["abc"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491 let mut pending_completions = fake_model.pending_completions();
492 assert_eq!(
493 pending_completions.len(),
494 1,
495 "unexpected pending completions: {:?}",
496 pending_completions
497 );
498
499 let pending_completion = pending_completions.pop().unwrap();
500 assert_eq!(pending_completion.messages[0].role, Role::System);
501
502 let system_message = &pending_completion.messages[0];
503 let system_prompt = system_message.content[0].to_str().unwrap();
504 assert!(
505 system_prompt.contains("test-shell"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509 assert!(
510 system_prompt.contains("## Fixing Diagnostics"),
511 "unexpected system message: {:?}",
512 system_message
513 );
514}
515
516#[gpui::test]
517async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
519 let fake_model = model.as_fake();
520
521 thread
522 .update(cx, |thread, cx| {
523 thread.send(UserMessageId::new(), ["abc"], cx)
524 })
525 .unwrap();
526 cx.run_until_parked();
527 let mut pending_completions = fake_model.pending_completions();
528 assert_eq!(
529 pending_completions.len(),
530 1,
531 "unexpected pending completions: {:?}",
532 pending_completions
533 );
534
535 let pending_completion = pending_completions.pop().unwrap();
536 assert_eq!(pending_completion.messages[0].role, Role::System);
537
538 let system_message = &pending_completion.messages[0];
539 let system_prompt = system_message.content[0].to_str().unwrap();
540 assert!(
541 !system_prompt.contains("## Tool Use"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545 assert!(
546 !system_prompt.contains("## Fixing Diagnostics"),
547 "unexpected system message: {:?}",
548 system_message
549 );
550}
551
552#[gpui::test]
553async fn test_prompt_caching(cx: &mut TestAppContext) {
554 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
555 let fake_model = model.as_fake();
556
557 // Send initial user message and verify it's cached
558 thread
559 .update(cx, |thread, cx| {
560 thread.send(UserMessageId::new(), ["Message 1"], cx)
561 })
562 .unwrap();
563 cx.run_until_parked();
564
565 let completion = fake_model.pending_completions().pop().unwrap();
566 assert_eq!(
567 completion.messages[1..],
568 vec![LanguageModelRequestMessage {
569 role: Role::User,
570 content: vec!["Message 1".into()],
571 cache: true,
572 reasoning_details: None,
573 }]
574 );
575 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
576 "Response to Message 1".into(),
577 ));
578 fake_model.end_last_completion_stream();
579 cx.run_until_parked();
580
581 // Send another user message and verify only the latest is cached
582 thread
583 .update(cx, |thread, cx| {
584 thread.send(UserMessageId::new(), ["Message 2"], cx)
585 })
586 .unwrap();
587 cx.run_until_parked();
588
589 let completion = fake_model.pending_completions().pop().unwrap();
590 assert_eq!(
591 completion.messages[1..],
592 vec![
593 LanguageModelRequestMessage {
594 role: Role::User,
595 content: vec!["Message 1".into()],
596 cache: false,
597 reasoning_details: None,
598 },
599 LanguageModelRequestMessage {
600 role: Role::Assistant,
601 content: vec!["Response to Message 1".into()],
602 cache: false,
603 reasoning_details: None,
604 },
605 LanguageModelRequestMessage {
606 role: Role::User,
607 content: vec!["Message 2".into()],
608 cache: true,
609 reasoning_details: None,
610 }
611 ]
612 );
613 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
614 "Response to Message 2".into(),
615 ));
616 fake_model.end_last_completion_stream();
617 cx.run_until_parked();
618
619 // Simulate a tool call and verify that the latest tool result is cached
620 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
621 thread
622 .update(cx, |thread, cx| {
623 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
624 })
625 .unwrap();
626 cx.run_until_parked();
627
628 let tool_use = LanguageModelToolUse {
629 id: "tool_1".into(),
630 name: EchoTool::NAME.into(),
631 raw_input: json!({"text": "test"}).to_string(),
632 input: json!({"text": "test"}),
633 is_input_complete: true,
634 thought_signature: None,
635 };
636 fake_model
637 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
638 fake_model.end_last_completion_stream();
639 cx.run_until_parked();
640
641 let completion = fake_model.pending_completions().pop().unwrap();
642 let tool_result = LanguageModelToolResult {
643 tool_use_id: "tool_1".into(),
644 tool_name: EchoTool::NAME.into(),
645 is_error: false,
646 content: "test".into(),
647 output: Some("test".into()),
648 };
649 assert_eq!(
650 completion.messages[1..],
651 vec![
652 LanguageModelRequestMessage {
653 role: Role::User,
654 content: vec!["Message 1".into()],
655 cache: false,
656 reasoning_details: None,
657 },
658 LanguageModelRequestMessage {
659 role: Role::Assistant,
660 content: vec!["Response to Message 1".into()],
661 cache: false,
662 reasoning_details: None,
663 },
664 LanguageModelRequestMessage {
665 role: Role::User,
666 content: vec!["Message 2".into()],
667 cache: false,
668 reasoning_details: None,
669 },
670 LanguageModelRequestMessage {
671 role: Role::Assistant,
672 content: vec!["Response to Message 2".into()],
673 cache: false,
674 reasoning_details: None,
675 },
676 LanguageModelRequestMessage {
677 role: Role::User,
678 content: vec!["Use the echo tool".into()],
679 cache: false,
680 reasoning_details: None,
681 },
682 LanguageModelRequestMessage {
683 role: Role::Assistant,
684 content: vec![MessageContent::ToolUse(tool_use)],
685 cache: false,
686 reasoning_details: None,
687 },
688 LanguageModelRequestMessage {
689 role: Role::User,
690 content: vec![MessageContent::ToolResult(tool_result)],
691 cache: true,
692 reasoning_details: None,
693 }
694 ]
695 );
696}
697
698#[gpui::test]
699#[cfg_attr(not(feature = "e2e"), ignore)]
700async fn test_basic_tool_calls(cx: &mut TestAppContext) {
701 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
702
703 // Test a tool call that's likely to complete *before* streaming stops.
704 let events = thread
705 .update(cx, |thread, cx| {
706 thread.add_tool(EchoTool, None);
707 thread.send(
708 UserMessageId::new(),
709 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
710 cx,
711 )
712 })
713 .unwrap()
714 .collect()
715 .await;
716 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
717
718 // Test a tool calls that's likely to complete *after* streaming stops.
719 let events = thread
720 .update(cx, |thread, cx| {
721 thread.remove_tool(&EchoTool::NAME);
722 thread.add_tool(DelayTool, None);
723 thread.send(
724 UserMessageId::new(),
725 [
726 "Now call the delay tool with 200ms.",
727 "When the timer goes off, then you echo the output of the tool.",
728 ],
729 cx,
730 )
731 })
732 .unwrap()
733 .collect()
734 .await;
735 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
736 thread.update(cx, |thread, _cx| {
737 assert!(
738 thread
739 .last_message()
740 .unwrap()
741 .as_agent_message()
742 .unwrap()
743 .content
744 .iter()
745 .any(|content| {
746 if let AgentMessageContent::Text(text) = content {
747 text.contains("Ding")
748 } else {
749 false
750 }
751 }),
752 "{}",
753 thread.to_markdown()
754 );
755 });
756}
757
758#[gpui::test]
759#[cfg_attr(not(feature = "e2e"), ignore)]
760async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
761 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
762
763 // Test a tool call that's likely to complete *before* streaming stops.
764 let mut events = thread
765 .update(cx, |thread, cx| {
766 thread.add_tool(WordListTool, None);
767 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
768 })
769 .unwrap();
770
771 let mut saw_partial_tool_use = false;
772 while let Some(event) = events.next().await {
773 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
774 thread.update(cx, |thread, _cx| {
775 // Look for a tool use in the thread's last message
776 let message = thread.last_message().unwrap();
777 let agent_message = message.as_agent_message().unwrap();
778 let last_content = agent_message.content.last().unwrap();
779 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
780 assert_eq!(last_tool_use.name.as_ref(), "word_list");
781 if tool_call.status == acp::ToolCallStatus::Pending {
782 if !last_tool_use.is_input_complete
783 && last_tool_use.input.get("g").is_none()
784 {
785 saw_partial_tool_use = true;
786 }
787 } else {
788 last_tool_use
789 .input
790 .get("a")
791 .expect("'a' has streamed because input is now complete");
792 last_tool_use
793 .input
794 .get("g")
795 .expect("'g' has streamed because input is now complete");
796 }
797 } else {
798 panic!("last content should be a tool use");
799 }
800 });
801 }
802 }
803
804 assert!(
805 saw_partial_tool_use,
806 "should see at least one partially streamed tool use in the history"
807 );
808}
809
810#[gpui::test]
811async fn test_tool_authorization(cx: &mut TestAppContext) {
812 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
813 let fake_model = model.as_fake();
814
815 let mut events = thread
816 .update(cx, |thread, cx| {
817 thread.add_tool(ToolRequiringPermission, None);
818 thread.send(UserMessageId::new(), ["abc"], cx)
819 })
820 .unwrap();
821 cx.run_until_parked();
822 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
823 LanguageModelToolUse {
824 id: "tool_id_1".into(),
825 name: ToolRequiringPermission::NAME.into(),
826 raw_input: "{}".into(),
827 input: json!({}),
828 is_input_complete: true,
829 thought_signature: None,
830 },
831 ));
832 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
833 LanguageModelToolUse {
834 id: "tool_id_2".into(),
835 name: ToolRequiringPermission::NAME.into(),
836 raw_input: "{}".into(),
837 input: json!({}),
838 is_input_complete: true,
839 thought_signature: None,
840 },
841 ));
842 fake_model.end_last_completion_stream();
843 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
844 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
845
846 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
847 tool_call_auth_1
848 .response
849 .send(acp::PermissionOptionId::new("allow"))
850 .unwrap();
851 cx.run_until_parked();
852
853 // Reject the second - send "deny" option_id directly since Deny is now a button
854 tool_call_auth_2
855 .response
856 .send(acp::PermissionOptionId::new("deny"))
857 .unwrap();
858 cx.run_until_parked();
859
860 let completion = fake_model.pending_completions().pop().unwrap();
861 let message = completion.messages.last().unwrap();
862 assert_eq!(
863 message.content,
864 vec![
865 language_model::MessageContent::ToolResult(LanguageModelToolResult {
866 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
867 tool_name: ToolRequiringPermission::NAME.into(),
868 is_error: false,
869 content: "Allowed".into(),
870 output: Some("Allowed".into())
871 }),
872 language_model::MessageContent::ToolResult(LanguageModelToolResult {
873 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
874 tool_name: ToolRequiringPermission::NAME.into(),
875 is_error: true,
876 content: "Permission to run tool denied by user".into(),
877 output: Some("Permission to run tool denied by user".into())
878 })
879 ]
880 );
881
882 // Simulate yet another tool call.
883 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
884 LanguageModelToolUse {
885 id: "tool_id_3".into(),
886 name: ToolRequiringPermission::NAME.into(),
887 raw_input: "{}".into(),
888 input: json!({}),
889 is_input_complete: true,
890 thought_signature: None,
891 },
892 ));
893 fake_model.end_last_completion_stream();
894
895 // Respond by always allowing tools - send transformed option_id
896 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
897 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
898 tool_call_auth_3
899 .response
900 .send(acp::PermissionOptionId::new(
901 "always_allow:tool_requiring_permission",
902 ))
903 .unwrap();
904 cx.run_until_parked();
905 let completion = fake_model.pending_completions().pop().unwrap();
906 let message = completion.messages.last().unwrap();
907 assert_eq!(
908 message.content,
909 vec![language_model::MessageContent::ToolResult(
910 LanguageModelToolResult {
911 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
912 tool_name: ToolRequiringPermission::NAME.into(),
913 is_error: false,
914 content: "Allowed".into(),
915 output: Some("Allowed".into())
916 }
917 )]
918 );
919
920 // Simulate a final tool call, ensuring we don't trigger authorization.
921 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
922 LanguageModelToolUse {
923 id: "tool_id_4".into(),
924 name: ToolRequiringPermission::NAME.into(),
925 raw_input: "{}".into(),
926 input: json!({}),
927 is_input_complete: true,
928 thought_signature: None,
929 },
930 ));
931 fake_model.end_last_completion_stream();
932 cx.run_until_parked();
933 let completion = fake_model.pending_completions().pop().unwrap();
934 let message = completion.messages.last().unwrap();
935 assert_eq!(
936 message.content,
937 vec![language_model::MessageContent::ToolResult(
938 LanguageModelToolResult {
939 tool_use_id: "tool_id_4".into(),
940 tool_name: ToolRequiringPermission::NAME.into(),
941 is_error: false,
942 content: "Allowed".into(),
943 output: Some("Allowed".into())
944 }
945 )]
946 );
947}
948
949#[gpui::test]
950async fn test_tool_hallucination(cx: &mut TestAppContext) {
951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
952 let fake_model = model.as_fake();
953
954 let mut events = thread
955 .update(cx, |thread, cx| {
956 thread.send(UserMessageId::new(), ["abc"], cx)
957 })
958 .unwrap();
959 cx.run_until_parked();
960 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
961 LanguageModelToolUse {
962 id: "tool_id_1".into(),
963 name: "nonexistent_tool".into(),
964 raw_input: "{}".into(),
965 input: json!({}),
966 is_input_complete: true,
967 thought_signature: None,
968 },
969 ));
970 fake_model.end_last_completion_stream();
971
972 let tool_call = expect_tool_call(&mut events).await;
973 assert_eq!(tool_call.title, "nonexistent_tool");
974 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
975 let update = expect_tool_call_update_fields(&mut events).await;
976 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
977}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
#[test]
fn test_permission_options_terminal_with_pattern() {
    // `cargo build --release` should produce three dropdown choices: the
    // tool-wide one, the `cargo` command one, and the one-time choice.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in [
        "Always for terminal",
        "Always for `cargo` commands",
        "Only this time",
    ] {
        assert!(labels.contains(&expected));
    }
}
1059
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing `src/main.rs` should offer a tool-wide choice and one for `src/`.
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(labels.contains(&expected));
    }
}
1077
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a `docs.rs` URL should offer a tool-wide entry and an entry
    // scoped to the `docs.rs` domain.
    let context = ToolPermissionContext::new(
        FetchTool::NAME,
        vec![String::from("https://docs.rs/gpui")],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }

    assert!(labels.iter().any(|label| *label == "Always for fetch"));
    assert!(labels.iter().any(|label| *label == "Always for `docs.rs`"));
}
1095
#[test]
fn test_permission_options_without_pattern() {
    // For `./deploy.sh --production` only the tool-wide and one-off entries are
    // expected; no entry mentioning "commands" may appear.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("./deploy.sh --production")],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 2);

    let mut labels: Vec<&str> = Vec::with_capacity(choices.len());
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }

    assert!(labels.iter().any(|label| *label == "Always for terminal"));
    assert!(labels.iter().any(|label| *label == "Only this time"));
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1117
#[test]
fn test_permission_option_ids_for_terminal() {
    // Checks the option ids attached to the allow and deny side of every
    // dropdown choice built for a `cargo` terminal command.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Gather both id lists in a single pass over the choices.
    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1157
1158#[gpui::test]
1159#[cfg_attr(not(feature = "e2e"), ignore)]
1160async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1161 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1162
1163 // Test concurrent tool calls with different delay times
1164 let events = thread
1165 .update(cx, |thread, cx| {
1166 thread.add_tool(DelayTool, None);
1167 thread.send(
1168 UserMessageId::new(),
1169 [
1170 "Call the delay tool twice in the same message.",
1171 "Once with 100ms. Once with 300ms.",
1172 "When both timers are complete, describe the outputs.",
1173 ],
1174 cx,
1175 )
1176 })
1177 .unwrap()
1178 .collect()
1179 .await;
1180
1181 let stop_reasons = stop_events(events);
1182 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1183
1184 thread.update(cx, |thread, _cx| {
1185 let last_message = thread.last_message().unwrap();
1186 let agent_message = last_message.as_agent_message().unwrap();
1187 let text = agent_message
1188 .content
1189 .iter()
1190 .filter_map(|content| {
1191 if let AgentMessageContent::Text(text) = content {
1192 Some(text.as_str())
1193 } else {
1194 None
1195 }
1196 })
1197 .collect::<String>();
1198
1199 assert!(text.contains("Ding"));
1200 });
1201}
1202
1203#[gpui::test]
1204async fn test_profiles(cx: &mut TestAppContext) {
1205 let ThreadTest {
1206 model, thread, fs, ..
1207 } = setup(cx, TestModel::Fake).await;
1208 let fake_model = model.as_fake();
1209
1210 thread.update(cx, |thread, _cx| {
1211 thread.add_tool(DelayTool, None);
1212 thread.add_tool(EchoTool, None);
1213 thread.add_tool(InfiniteTool, None);
1214 });
1215
1216 // Override profiles and wait for settings to be loaded.
1217 fs.insert_file(
1218 paths::settings_file(),
1219 json!({
1220 "agent": {
1221 "profiles": {
1222 "test-1": {
1223 "name": "Test Profile 1",
1224 "tools": {
1225 EchoTool::NAME: true,
1226 DelayTool::NAME: true,
1227 }
1228 },
1229 "test-2": {
1230 "name": "Test Profile 2",
1231 "tools": {
1232 InfiniteTool::NAME: true,
1233 }
1234 }
1235 }
1236 }
1237 })
1238 .to_string()
1239 .into_bytes(),
1240 )
1241 .await;
1242 cx.run_until_parked();
1243
1244 // Test that test-1 profile (default) has echo and delay tools
1245 thread
1246 .update(cx, |thread, cx| {
1247 thread.set_profile(AgentProfileId("test-1".into()), cx);
1248 thread.send(UserMessageId::new(), ["test"], cx)
1249 })
1250 .unwrap();
1251 cx.run_until_parked();
1252
1253 let mut pending_completions = fake_model.pending_completions();
1254 assert_eq!(pending_completions.len(), 1);
1255 let completion = pending_completions.pop().unwrap();
1256 let tool_names: Vec<String> = completion
1257 .tools
1258 .iter()
1259 .map(|tool| tool.name.clone())
1260 .collect();
1261 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1262 fake_model.end_last_completion_stream();
1263
1264 // Switch to test-2 profile, and verify that it has only the infinite tool.
1265 thread
1266 .update(cx, |thread, cx| {
1267 thread.set_profile(AgentProfileId("test-2".into()), cx);
1268 thread.send(UserMessageId::new(), ["test2"], cx)
1269 })
1270 .unwrap();
1271 cx.run_until_parked();
1272 let mut pending_completions = fake_model.pending_completions();
1273 assert_eq!(pending_completions.len(), 1);
1274 let completion = pending_completions.pop().unwrap();
1275 let tool_names: Vec<String> = completion
1276 .tools
1277 .iter()
1278 .map(|tool| tool.name.clone())
1279 .collect();
1280 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1281}
1282
1283#[gpui::test]
1284async fn test_mcp_tools(cx: &mut TestAppContext) {
1285 let ThreadTest {
1286 model,
1287 thread,
1288 context_server_store,
1289 fs,
1290 ..
1291 } = setup(cx, TestModel::Fake).await;
1292 let fake_model = model.as_fake();
1293
1294 // Override profiles and wait for settings to be loaded.
1295 fs.insert_file(
1296 paths::settings_file(),
1297 json!({
1298 "agent": {
1299 "tool_permissions": { "default": "allow" },
1300 "profiles": {
1301 "test": {
1302 "name": "Test Profile",
1303 "enable_all_context_servers": true,
1304 "tools": {
1305 EchoTool::NAME: true,
1306 }
1307 },
1308 }
1309 }
1310 })
1311 .to_string()
1312 .into_bytes(),
1313 )
1314 .await;
1315 cx.run_until_parked();
1316 thread.update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test".into()), cx)
1318 });
1319
1320 let mut mcp_tool_calls = setup_context_server(
1321 "test_server",
1322 vec![context_server::types::Tool {
1323 name: "echo".into(),
1324 description: None,
1325 input_schema: serde_json::to_value(EchoTool::input_schema(
1326 LanguageModelToolSchemaFormat::JsonSchema,
1327 ))
1328 .unwrap(),
1329 output_schema: None,
1330 annotations: None,
1331 }],
1332 &context_server_store,
1333 cx,
1334 );
1335
1336 let events = thread.update(cx, |thread, cx| {
1337 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1338 });
1339 cx.run_until_parked();
1340
1341 // Simulate the model calling the MCP tool.
1342 let completion = fake_model.pending_completions().pop().unwrap();
1343 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1344 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1345 LanguageModelToolUse {
1346 id: "tool_1".into(),
1347 name: "echo".into(),
1348 raw_input: json!({"text": "test"}).to_string(),
1349 input: json!({"text": "test"}),
1350 is_input_complete: true,
1351 thought_signature: None,
1352 },
1353 ));
1354 fake_model.end_last_completion_stream();
1355 cx.run_until_parked();
1356
1357 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1358 assert_eq!(tool_call_params.name, "echo");
1359 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1360 tool_call_response
1361 .send(context_server::types::CallToolResponse {
1362 content: vec![context_server::types::ToolResponseContent::Text {
1363 text: "test".into(),
1364 }],
1365 is_error: None,
1366 meta: None,
1367 structured_content: None,
1368 })
1369 .unwrap();
1370 cx.run_until_parked();
1371
1372 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1373 fake_model.send_last_completion_stream_text_chunk("Done!");
1374 fake_model.end_last_completion_stream();
1375 events.collect::<Vec<_>>().await;
1376
1377 // Send again after adding the echo tool, ensuring the name collision is resolved.
1378 let events = thread.update(cx, |thread, cx| {
1379 thread.add_tool(EchoTool, None);
1380 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1381 });
1382 cx.run_until_parked();
1383 let completion = fake_model.pending_completions().pop().unwrap();
1384 assert_eq!(
1385 tool_names_for_completion(&completion),
1386 vec!["echo", "test_server_echo"]
1387 );
1388 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1389 LanguageModelToolUse {
1390 id: "tool_2".into(),
1391 name: "test_server_echo".into(),
1392 raw_input: json!({"text": "mcp"}).to_string(),
1393 input: json!({"text": "mcp"}),
1394 is_input_complete: true,
1395 thought_signature: None,
1396 },
1397 ));
1398 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1399 LanguageModelToolUse {
1400 id: "tool_3".into(),
1401 name: "echo".into(),
1402 raw_input: json!({"text": "native"}).to_string(),
1403 input: json!({"text": "native"}),
1404 is_input_complete: true,
1405 thought_signature: None,
1406 },
1407 ));
1408 fake_model.end_last_completion_stream();
1409 cx.run_until_parked();
1410
1411 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1412 assert_eq!(tool_call_params.name, "echo");
1413 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1414 tool_call_response
1415 .send(context_server::types::CallToolResponse {
1416 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1417 is_error: None,
1418 meta: None,
1419 structured_content: None,
1420 })
1421 .unwrap();
1422 cx.run_until_parked();
1423
1424 // Ensure the tool results were inserted with the correct names.
1425 let completion = fake_model.pending_completions().pop().unwrap();
1426 assert_eq!(
1427 completion.messages.last().unwrap().content,
1428 vec![
1429 MessageContent::ToolResult(LanguageModelToolResult {
1430 tool_use_id: "tool_3".into(),
1431 tool_name: "echo".into(),
1432 is_error: false,
1433 content: "native".into(),
1434 output: Some("native".into()),
1435 },),
1436 MessageContent::ToolResult(LanguageModelToolResult {
1437 tool_use_id: "tool_2".into(),
1438 tool_name: "test_server_echo".into(),
1439 is_error: false,
1440 content: "mcp".into(),
1441 output: Some("mcp".into()),
1442 },),
1443 ]
1444 );
1445 fake_model.end_last_completion_stream();
1446 events.collect::<Vec<_>>().await;
1447}
1448
1449#[gpui::test]
1450async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1451 let ThreadTest {
1452 model,
1453 thread,
1454 context_server_store,
1455 fs,
1456 ..
1457 } = setup(cx, TestModel::Fake).await;
1458 let fake_model = model.as_fake();
1459
1460 // Set up a profile with all tools enabled
1461 fs.insert_file(
1462 paths::settings_file(),
1463 json!({
1464 "agent": {
1465 "profiles": {
1466 "test": {
1467 "name": "Test Profile",
1468 "enable_all_context_servers": true,
1469 "tools": {
1470 EchoTool::NAME: true,
1471 DelayTool::NAME: true,
1472 WordListTool::NAME: true,
1473 ToolRequiringPermission::NAME: true,
1474 InfiniteTool::NAME: true,
1475 }
1476 },
1477 }
1478 }
1479 })
1480 .to_string()
1481 .into_bytes(),
1482 )
1483 .await;
1484 cx.run_until_parked();
1485
1486 thread.update(cx, |thread, cx| {
1487 thread.set_profile(AgentProfileId("test".into()), cx);
1488 thread.add_tool(EchoTool, None);
1489 thread.add_tool(DelayTool, None);
1490 thread.add_tool(WordListTool, None);
1491 thread.add_tool(ToolRequiringPermission, None);
1492 thread.add_tool(InfiniteTool, None);
1493 });
1494
1495 // Set up multiple context servers with some overlapping tool names
1496 let _server1_calls = setup_context_server(
1497 "xxx",
1498 vec![
1499 context_server::types::Tool {
1500 name: "echo".into(), // Conflicts with native EchoTool
1501 description: None,
1502 input_schema: serde_json::to_value(EchoTool::input_schema(
1503 LanguageModelToolSchemaFormat::JsonSchema,
1504 ))
1505 .unwrap(),
1506 output_schema: None,
1507 annotations: None,
1508 },
1509 context_server::types::Tool {
1510 name: "unique_tool_1".into(),
1511 description: None,
1512 input_schema: json!({"type": "object", "properties": {}}),
1513 output_schema: None,
1514 annotations: None,
1515 },
1516 ],
1517 &context_server_store,
1518 cx,
1519 );
1520
1521 let _server2_calls = setup_context_server(
1522 "yyy",
1523 vec![
1524 context_server::types::Tool {
1525 name: "echo".into(), // Also conflicts with native EchoTool
1526 description: None,
1527 input_schema: serde_json::to_value(EchoTool::input_schema(
1528 LanguageModelToolSchemaFormat::JsonSchema,
1529 ))
1530 .unwrap(),
1531 output_schema: None,
1532 annotations: None,
1533 },
1534 context_server::types::Tool {
1535 name: "unique_tool_2".into(),
1536 description: None,
1537 input_schema: json!({"type": "object", "properties": {}}),
1538 output_schema: None,
1539 annotations: None,
1540 },
1541 context_server::types::Tool {
1542 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1543 description: None,
1544 input_schema: json!({"type": "object", "properties": {}}),
1545 output_schema: None,
1546 annotations: None,
1547 },
1548 context_server::types::Tool {
1549 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1550 description: None,
1551 input_schema: json!({"type": "object", "properties": {}}),
1552 output_schema: None,
1553 annotations: None,
1554 },
1555 ],
1556 &context_server_store,
1557 cx,
1558 );
1559 let _server3_calls = setup_context_server(
1560 "zzz",
1561 vec![
1562 context_server::types::Tool {
1563 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1564 description: None,
1565 input_schema: json!({"type": "object", "properties": {}}),
1566 output_schema: None,
1567 annotations: None,
1568 },
1569 context_server::types::Tool {
1570 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1571 description: None,
1572 input_schema: json!({"type": "object", "properties": {}}),
1573 output_schema: None,
1574 annotations: None,
1575 },
1576 context_server::types::Tool {
1577 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1578 description: None,
1579 input_schema: json!({"type": "object", "properties": {}}),
1580 output_schema: None,
1581 annotations: None,
1582 },
1583 ],
1584 &context_server_store,
1585 cx,
1586 );
1587
1588 thread
1589 .update(cx, |thread, cx| {
1590 thread.send(UserMessageId::new(), ["Go"], cx)
1591 })
1592 .unwrap();
1593 cx.run_until_parked();
1594 let completion = fake_model.pending_completions().pop().unwrap();
1595 assert_eq!(
1596 tool_names_for_completion(&completion),
1597 vec![
1598 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1599 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1600 "delay",
1601 "echo",
1602 "infinite",
1603 "tool_requiring_permission",
1604 "unique_tool_1",
1605 "unique_tool_2",
1606 "word_list",
1607 "xxx_echo",
1608 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1609 "yyy_echo",
1610 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1611 ]
1612 );
1613}
1614
1615#[gpui::test]
1616#[cfg_attr(not(feature = "e2e"), ignore)]
1617async fn test_cancellation(cx: &mut TestAppContext) {
1618 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1619
1620 let mut events = thread
1621 .update(cx, |thread, cx| {
1622 thread.add_tool(InfiniteTool, None);
1623 thread.add_tool(EchoTool, None);
1624 thread.send(
1625 UserMessageId::new(),
1626 ["Call the echo tool, then call the infinite tool, then explain their output"],
1627 cx,
1628 )
1629 })
1630 .unwrap();
1631
1632 // Wait until both tools are called.
1633 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1634 let mut echo_id = None;
1635 let mut echo_completed = false;
1636 while let Some(event) = events.next().await {
1637 match event.unwrap() {
1638 ThreadEvent::ToolCall(tool_call) => {
1639 assert_eq!(tool_call.title, expected_tools.remove(0));
1640 if tool_call.title == "Echo" {
1641 echo_id = Some(tool_call.tool_call_id);
1642 }
1643 }
1644 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1645 acp::ToolCallUpdate {
1646 tool_call_id,
1647 fields:
1648 acp::ToolCallUpdateFields {
1649 status: Some(acp::ToolCallStatus::Completed),
1650 ..
1651 },
1652 ..
1653 },
1654 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1655 echo_completed = true;
1656 }
1657 _ => {}
1658 }
1659
1660 if expected_tools.is_empty() && echo_completed {
1661 break;
1662 }
1663 }
1664
1665 // Cancel the current send and ensure that the event stream is closed, even
1666 // if one of the tools is still running.
1667 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1668 let events = events.collect::<Vec<_>>().await;
1669 let last_event = events.last();
1670 assert!(
1671 matches!(
1672 last_event,
1673 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1674 ),
1675 "unexpected event {last_event:?}"
1676 );
1677
1678 // Ensure we can still send a new message after cancellation.
1679 let events = thread
1680 .update(cx, |thread, cx| {
1681 thread.send(
1682 UserMessageId::new(),
1683 ["Testing: reply with 'Hello' then stop."],
1684 cx,
1685 )
1686 })
1687 .unwrap()
1688 .collect::<Vec<_>>()
1689 .await;
1690 thread.update(cx, |thread, _cx| {
1691 let message = thread.last_message().unwrap();
1692 let agent_message = message.as_agent_message().unwrap();
1693 assert_eq!(
1694 agent_message.content,
1695 vec![AgentMessageContent::Text("Hello".to_string())]
1696 );
1697 });
1698 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1699}
1700
1701#[gpui::test]
1702async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1703 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1704 always_allow_tools(cx);
1705 let fake_model = model.as_fake();
1706
1707 let environment = Rc::new(cx.update(|cx| {
1708 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1709 }));
1710 let handle = environment.terminal_handle.clone().unwrap();
1711
1712 let mut events = thread
1713 .update(cx, |thread, cx| {
1714 thread.add_tool(
1715 crate::TerminalTool::new(thread.project().clone(), environment),
1716 None,
1717 );
1718 thread.send(UserMessageId::new(), ["run a command"], cx)
1719 })
1720 .unwrap();
1721
1722 cx.run_until_parked();
1723
1724 // Simulate the model calling the terminal tool
1725 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1726 LanguageModelToolUse {
1727 id: "terminal_tool_1".into(),
1728 name: TerminalTool::NAME.into(),
1729 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1730 input: json!({"command": "sleep 1000", "cd": "."}),
1731 is_input_complete: true,
1732 thought_signature: None,
1733 },
1734 ));
1735 fake_model.end_last_completion_stream();
1736
1737 // Wait for the terminal tool to start running
1738 wait_for_terminal_tool_started(&mut events, cx).await;
1739
1740 // Cancel the thread while the terminal is running
1741 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1742
1743 // Collect remaining events, driving the executor to let cancellation complete
1744 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1745
1746 // Verify the terminal was killed
1747 assert!(
1748 handle.was_killed(),
1749 "expected terminal handle to be killed on cancellation"
1750 );
1751
1752 // Verify we got a cancellation stop event
1753 assert_eq!(
1754 stop_events(remaining_events),
1755 vec![acp::StopReason::Cancelled],
1756 );
1757
1758 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1759 thread.update(cx, |thread, _cx| {
1760 let message = thread.last_message().unwrap();
1761 let agent_message = message.as_agent_message().unwrap();
1762
1763 let tool_use = agent_message
1764 .content
1765 .iter()
1766 .find_map(|content| match content {
1767 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1768 _ => None,
1769 })
1770 .expect("expected tool use in agent message");
1771
1772 let tool_result = agent_message
1773 .tool_results
1774 .get(&tool_use.id)
1775 .expect("expected tool result");
1776
1777 let result_text = match &tool_result.content {
1778 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1779 _ => panic!("expected text content in tool result"),
1780 };
1781
1782 // "partial output" comes from FakeTerminalHandle's output field
1783 assert!(
1784 result_text.contains("partial output"),
1785 "expected tool result to contain terminal output, got: {result_text}"
1786 );
1787 // Match the actual format from process_content in terminal_tool.rs
1788 assert!(
1789 result_text.contains("The user stopped this command"),
1790 "expected tool result to indicate user stopped, got: {result_text}"
1791 );
1792 });
1793
1794 // Verify we can send a new message after cancellation
1795 verify_thread_recovery(&thread, &fake_model, cx).await;
1796}
1797
1798#[gpui::test]
1799async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1800 // This test verifies that tools which properly handle cancellation via
1801 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1802 // to cancellation and report that they were cancelled.
1803 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1804 always_allow_tools(cx);
1805 let fake_model = model.as_fake();
1806
1807 let (tool, was_cancelled) = CancellationAwareTool::new();
1808
1809 let mut events = thread
1810 .update(cx, |thread, cx| {
1811 thread.add_tool(tool, None);
1812 thread.send(
1813 UserMessageId::new(),
1814 ["call the cancellation aware tool"],
1815 cx,
1816 )
1817 })
1818 .unwrap();
1819
1820 cx.run_until_parked();
1821
1822 // Simulate the model calling the cancellation-aware tool
1823 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1824 LanguageModelToolUse {
1825 id: "cancellation_aware_1".into(),
1826 name: "cancellation_aware".into(),
1827 raw_input: r#"{}"#.into(),
1828 input: json!({}),
1829 is_input_complete: true,
1830 thought_signature: None,
1831 },
1832 ));
1833 fake_model.end_last_completion_stream();
1834
1835 cx.run_until_parked();
1836
1837 // Wait for the tool call to be reported
1838 let mut tool_started = false;
1839 let deadline = cx.executor().num_cpus() * 100;
1840 for _ in 0..deadline {
1841 cx.run_until_parked();
1842
1843 while let Some(Some(event)) = events.next().now_or_never() {
1844 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1845 if tool_call.title == "Cancellation Aware Tool" {
1846 tool_started = true;
1847 break;
1848 }
1849 }
1850 }
1851
1852 if tool_started {
1853 break;
1854 }
1855
1856 cx.background_executor
1857 .timer(Duration::from_millis(10))
1858 .await;
1859 }
1860 assert!(tool_started, "expected cancellation aware tool to start");
1861
1862 // Cancel the thread and wait for it to complete
1863 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1864
1865 // The cancel task should complete promptly because the tool handles cancellation
1866 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1867 futures::select! {
1868 _ = cancel_task.fuse() => {}
1869 _ = timeout.fuse() => {
1870 panic!("cancel task timed out - tool did not respond to cancellation");
1871 }
1872 }
1873
1874 // Verify the tool detected cancellation via its flag
1875 assert!(
1876 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1877 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1878 );
1879
1880 // Collect remaining events
1881 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1882
1883 // Verify we got a cancellation stop event
1884 assert_eq!(
1885 stop_events(remaining_events),
1886 vec![acp::StopReason::Cancelled],
1887 );
1888
1889 // Verify we can send a new message after cancellation
1890 verify_thread_recovery(&thread, &fake_model, cx).await;
1891}
1892
1893/// Helper to verify thread can recover after cancellation by sending a simple message.
1894async fn verify_thread_recovery(
1895 thread: &Entity<Thread>,
1896 fake_model: &FakeLanguageModel,
1897 cx: &mut TestAppContext,
1898) {
1899 let events = thread
1900 .update(cx, |thread, cx| {
1901 thread.send(
1902 UserMessageId::new(),
1903 ["Testing: reply with 'Hello' then stop."],
1904 cx,
1905 )
1906 })
1907 .unwrap();
1908 cx.run_until_parked();
1909 fake_model.send_last_completion_stream_text_chunk("Hello");
1910 fake_model
1911 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1912 fake_model.end_last_completion_stream();
1913
1914 let events = events.collect::<Vec<_>>().await;
1915 thread.update(cx, |thread, _cx| {
1916 let message = thread.last_message().unwrap();
1917 let agent_message = message.as_agent_message().unwrap();
1918 assert_eq!(
1919 agent_message.content,
1920 vec![AgentMessageContent::Text("Hello".to_string())]
1921 );
1922 });
1923 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1924}
1925
1926/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1927async fn wait_for_terminal_tool_started(
1928 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1929 cx: &mut TestAppContext,
1930) {
1931 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1932 for _ in 0..deadline {
1933 cx.run_until_parked();
1934
1935 while let Some(Some(event)) = events.next().now_or_never() {
1936 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1937 update,
1938 ))) = &event
1939 {
1940 if update.fields.content.as_ref().is_some_and(|content| {
1941 content
1942 .iter()
1943 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1944 }) {
1945 return;
1946 }
1947 }
1948 }
1949
1950 cx.background_executor
1951 .timer(Duration::from_millis(10))
1952 .await;
1953 }
1954 panic!("terminal tool did not start within the expected time");
1955}
1956
1957/// Collects events until a Stop event is received, driving the executor to completion.
1958async fn collect_events_until_stop(
1959 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1960 cx: &mut TestAppContext,
1961) -> Vec<Result<ThreadEvent>> {
1962 let mut collected = Vec::new();
1963 let deadline = cx.executor().num_cpus() * 200;
1964
1965 for _ in 0..deadline {
1966 cx.executor().advance_clock(Duration::from_millis(10));
1967 cx.run_until_parked();
1968
1969 while let Some(Some(event)) = events.next().now_or_never() {
1970 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1971 collected.push(event);
1972 if is_stop {
1973 return collected;
1974 }
1975 }
1976 }
1977 panic!(
1978 "did not receive Stop event within the expected time; collected {} events",
1979 collected.len()
1980 );
1981}
1982
1983#[gpui::test]
1984async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1985 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1986 always_allow_tools(cx);
1987 let fake_model = model.as_fake();
1988
1989 let environment = Rc::new(cx.update(|cx| {
1990 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1991 }));
1992 let handle = environment.terminal_handle.clone().unwrap();
1993
1994 let message_id = UserMessageId::new();
1995 let mut events = thread
1996 .update(cx, |thread, cx| {
1997 thread.add_tool(
1998 crate::TerminalTool::new(thread.project().clone(), environment),
1999 None,
2000 );
2001 thread.send(message_id.clone(), ["run a command"], cx)
2002 })
2003 .unwrap();
2004
2005 cx.run_until_parked();
2006
2007 // Simulate the model calling the terminal tool
2008 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2009 LanguageModelToolUse {
2010 id: "terminal_tool_1".into(),
2011 name: TerminalTool::NAME.into(),
2012 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2013 input: json!({"command": "sleep 1000", "cd": "."}),
2014 is_input_complete: true,
2015 thought_signature: None,
2016 },
2017 ));
2018 fake_model.end_last_completion_stream();
2019
2020 // Wait for the terminal tool to start running
2021 wait_for_terminal_tool_started(&mut events, cx).await;
2022
2023 // Truncate the thread while the terminal is running
2024 thread
2025 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2026 .unwrap();
2027
2028 // Drive the executor to let cancellation complete
2029 let _ = collect_events_until_stop(&mut events, cx).await;
2030
2031 // Verify the terminal was killed
2032 assert!(
2033 handle.was_killed(),
2034 "expected terminal handle to be killed on truncate"
2035 );
2036
2037 // Verify the thread is empty after truncation
2038 thread.update(cx, |thread, _cx| {
2039 assert_eq!(
2040 thread.to_markdown(),
2041 "",
2042 "expected thread to be empty after truncating the only message"
2043 );
2044 });
2045
2046 // Verify we can send a new message after truncation
2047 verify_thread_recovery(&thread, &fake_model, cx).await;
2048}
2049
2050#[gpui::test]
2051async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2052 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2053 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2054 always_allow_tools(cx);
2055 let fake_model = model.as_fake();
2056
2057 let environment = Rc::new(MultiTerminalEnvironment::new());
2058
2059 let mut events = thread
2060 .update(cx, |thread, cx| {
2061 thread.add_tool(
2062 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2063 None,
2064 );
2065 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2066 })
2067 .unwrap();
2068
2069 cx.run_until_parked();
2070
2071 // Simulate the model calling two terminal tools
2072 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2073 LanguageModelToolUse {
2074 id: "terminal_tool_1".into(),
2075 name: TerminalTool::NAME.into(),
2076 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2077 input: json!({"command": "sleep 1000", "cd": "."}),
2078 is_input_complete: true,
2079 thought_signature: None,
2080 },
2081 ));
2082 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2083 LanguageModelToolUse {
2084 id: "terminal_tool_2".into(),
2085 name: TerminalTool::NAME.into(),
2086 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2087 input: json!({"command": "sleep 2000", "cd": "."}),
2088 is_input_complete: true,
2089 thought_signature: None,
2090 },
2091 ));
2092 fake_model.end_last_completion_stream();
2093
2094 // Wait for both terminal tools to start by counting terminal content updates
2095 let mut terminals_started = 0;
2096 let deadline = cx.executor().num_cpus() * 100;
2097 for _ in 0..deadline {
2098 cx.run_until_parked();
2099
2100 while let Some(Some(event)) = events.next().now_or_never() {
2101 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2102 update,
2103 ))) = &event
2104 {
2105 if update.fields.content.as_ref().is_some_and(|content| {
2106 content
2107 .iter()
2108 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2109 }) {
2110 terminals_started += 1;
2111 if terminals_started >= 2 {
2112 break;
2113 }
2114 }
2115 }
2116 }
2117 if terminals_started >= 2 {
2118 break;
2119 }
2120
2121 cx.background_executor
2122 .timer(Duration::from_millis(10))
2123 .await;
2124 }
2125 assert!(
2126 terminals_started >= 2,
2127 "expected 2 terminal tools to start, got {terminals_started}"
2128 );
2129
2130 // Cancel the thread while both terminals are running
2131 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2132
2133 // Collect remaining events
2134 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2135
2136 // Verify both terminal handles were killed
2137 let handles = environment.handles();
2138 assert_eq!(
2139 handles.len(),
2140 2,
2141 "expected 2 terminal handles to be created"
2142 );
2143 assert!(
2144 handles[0].was_killed(),
2145 "expected first terminal handle to be killed on cancellation"
2146 );
2147 assert!(
2148 handles[1].was_killed(),
2149 "expected second terminal handle to be killed on cancellation"
2150 );
2151
2152 // Verify we got a cancellation stop event
2153 assert_eq!(
2154 stop_events(remaining_events),
2155 vec![acp::StopReason::Cancelled],
2156 );
2157}
2158
2159#[gpui::test]
2160async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2161 // Tests that clicking the stop button on the terminal card (as opposed to the main
2162 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2163 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2164 always_allow_tools(cx);
2165 let fake_model = model.as_fake();
2166
2167 let environment = Rc::new(cx.update(|cx| {
2168 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2169 }));
2170 let handle = environment.terminal_handle.clone().unwrap();
2171
2172 let mut events = thread
2173 .update(cx, |thread, cx| {
2174 thread.add_tool(
2175 crate::TerminalTool::new(thread.project().clone(), environment),
2176 None,
2177 );
2178 thread.send(UserMessageId::new(), ["run a command"], cx)
2179 })
2180 .unwrap();
2181
2182 cx.run_until_parked();
2183
2184 // Simulate the model calling the terminal tool
2185 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2186 LanguageModelToolUse {
2187 id: "terminal_tool_1".into(),
2188 name: TerminalTool::NAME.into(),
2189 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2190 input: json!({"command": "sleep 1000", "cd": "."}),
2191 is_input_complete: true,
2192 thought_signature: None,
2193 },
2194 ));
2195 fake_model.end_last_completion_stream();
2196
2197 // Wait for the terminal tool to start running
2198 wait_for_terminal_tool_started(&mut events, cx).await;
2199
2200 // Simulate user clicking stop on the terminal card itself.
2201 // This sets the flag and signals exit (simulating what the real UI would do).
2202 handle.set_stopped_by_user(true);
2203 handle.killed.store(true, Ordering::SeqCst);
2204 handle.signal_exit();
2205
2206 // Wait for the tool to complete
2207 cx.run_until_parked();
2208
2209 // The thread continues after tool completion - simulate the model ending its turn
2210 fake_model
2211 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2212 fake_model.end_last_completion_stream();
2213
2214 // Collect remaining events
2215 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2216
2217 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2218 assert_eq!(
2219 stop_events(remaining_events),
2220 vec![acp::StopReason::EndTurn],
2221 );
2222
2223 // Verify the tool result indicates user stopped
2224 thread.update(cx, |thread, _cx| {
2225 let message = thread.last_message().unwrap();
2226 let agent_message = message.as_agent_message().unwrap();
2227
2228 let tool_use = agent_message
2229 .content
2230 .iter()
2231 .find_map(|content| match content {
2232 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2233 _ => None,
2234 })
2235 .expect("expected tool use in agent message");
2236
2237 let tool_result = agent_message
2238 .tool_results
2239 .get(&tool_use.id)
2240 .expect("expected tool result");
2241
2242 let result_text = match &tool_result.content {
2243 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2244 _ => panic!("expected text content in tool result"),
2245 };
2246
2247 assert!(
2248 result_text.contains("The user stopped this command"),
2249 "expected tool result to indicate user stopped, got: {result_text}"
2250 );
2251 });
2252}
2253
2254#[gpui::test]
2255async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2256 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2257 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2258 always_allow_tools(cx);
2259 let fake_model = model.as_fake();
2260
2261 let environment = Rc::new(cx.update(|cx| {
2262 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2263 }));
2264 let handle = environment.terminal_handle.clone().unwrap();
2265
2266 let mut events = thread
2267 .update(cx, |thread, cx| {
2268 thread.add_tool(
2269 crate::TerminalTool::new(thread.project().clone(), environment),
2270 None,
2271 );
2272 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2273 })
2274 .unwrap();
2275
2276 cx.run_until_parked();
2277
2278 // Simulate the model calling the terminal tool with a short timeout
2279 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2280 LanguageModelToolUse {
2281 id: "terminal_tool_1".into(),
2282 name: TerminalTool::NAME.into(),
2283 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2284 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2285 is_input_complete: true,
2286 thought_signature: None,
2287 },
2288 ));
2289 fake_model.end_last_completion_stream();
2290
2291 // Wait for the terminal tool to start running
2292 wait_for_terminal_tool_started(&mut events, cx).await;
2293
2294 // Advance clock past the timeout
2295 cx.executor().advance_clock(Duration::from_millis(200));
2296 cx.run_until_parked();
2297
2298 // The thread continues after tool completion - simulate the model ending its turn
2299 fake_model
2300 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2301 fake_model.end_last_completion_stream();
2302
2303 // Collect remaining events
2304 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2305
2306 // Verify the terminal was killed due to timeout
2307 assert!(
2308 handle.was_killed(),
2309 "expected terminal handle to be killed on timeout"
2310 );
2311
2312 // Verify we got an EndTurn (the tool completed, just with timeout)
2313 assert_eq!(
2314 stop_events(remaining_events),
2315 vec![acp::StopReason::EndTurn],
2316 );
2317
2318 // Verify the tool result indicates timeout, not user stopped
2319 thread.update(cx, |thread, _cx| {
2320 let message = thread.last_message().unwrap();
2321 let agent_message = message.as_agent_message().unwrap();
2322
2323 let tool_use = agent_message
2324 .content
2325 .iter()
2326 .find_map(|content| match content {
2327 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2328 _ => None,
2329 })
2330 .expect("expected tool use in agent message");
2331
2332 let tool_result = agent_message
2333 .tool_results
2334 .get(&tool_use.id)
2335 .expect("expected tool result");
2336
2337 let result_text = match &tool_result.content {
2338 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2339 _ => panic!("expected text content in tool result"),
2340 };
2341
2342 assert!(
2343 result_text.contains("timed out"),
2344 "expected tool result to indicate timeout, got: {result_text}"
2345 );
2346 assert!(
2347 !result_text.contains("The user stopped"),
2348 "tool result should not mention user stopped when it timed out, got: {result_text}"
2349 );
2350 });
2351}
2352
2353#[gpui::test]
2354async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2355 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2356 let fake_model = model.as_fake();
2357
2358 let events_1 = thread
2359 .update(cx, |thread, cx| {
2360 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2361 })
2362 .unwrap();
2363 cx.run_until_parked();
2364 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2365 cx.run_until_parked();
2366
2367 let events_2 = thread
2368 .update(cx, |thread, cx| {
2369 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2370 })
2371 .unwrap();
2372 cx.run_until_parked();
2373 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2374 fake_model
2375 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2376 fake_model.end_last_completion_stream();
2377
2378 let events_1 = events_1.collect::<Vec<_>>().await;
2379 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2380 let events_2 = events_2.collect::<Vec<_>>().await;
2381 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2382}
2383
2384#[gpui::test]
2385async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2386 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2387 let fake_model = model.as_fake();
2388
2389 let events_1 = thread
2390 .update(cx, |thread, cx| {
2391 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2392 })
2393 .unwrap();
2394 cx.run_until_parked();
2395 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2396 fake_model
2397 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2398 fake_model.end_last_completion_stream();
2399 let events_1 = events_1.collect::<Vec<_>>().await;
2400
2401 let events_2 = thread
2402 .update(cx, |thread, cx| {
2403 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2404 })
2405 .unwrap();
2406 cx.run_until_parked();
2407 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2408 fake_model
2409 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2410 fake_model.end_last_completion_stream();
2411 let events_2 = events_2.collect::<Vec<_>>().await;
2412
2413 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2414 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2415}
2416
2417#[gpui::test]
2418async fn test_refusal(cx: &mut TestAppContext) {
2419 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2420 let fake_model = model.as_fake();
2421
2422 let events = thread
2423 .update(cx, |thread, cx| {
2424 thread.send(UserMessageId::new(), ["Hello"], cx)
2425 })
2426 .unwrap();
2427 cx.run_until_parked();
2428 thread.read_with(cx, |thread, _| {
2429 assert_eq!(
2430 thread.to_markdown(),
2431 indoc! {"
2432 ## User
2433
2434 Hello
2435 "}
2436 );
2437 });
2438
2439 fake_model.send_last_completion_stream_text_chunk("Hey!");
2440 cx.run_until_parked();
2441 thread.read_with(cx, |thread, _| {
2442 assert_eq!(
2443 thread.to_markdown(),
2444 indoc! {"
2445 ## User
2446
2447 Hello
2448
2449 ## Assistant
2450
2451 Hey!
2452 "}
2453 );
2454 });
2455
2456 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2457 fake_model
2458 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2459 let events = events.collect::<Vec<_>>().await;
2460 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2461 thread.read_with(cx, |thread, _| {
2462 assert_eq!(thread.to_markdown(), "");
2463 });
2464}
2465
2466#[gpui::test]
2467async fn test_truncate_first_message(cx: &mut TestAppContext) {
2468 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2469 let fake_model = model.as_fake();
2470
2471 let message_id = UserMessageId::new();
2472 thread
2473 .update(cx, |thread, cx| {
2474 thread.send(message_id.clone(), ["Hello"], cx)
2475 })
2476 .unwrap();
2477 cx.run_until_parked();
2478 thread.read_with(cx, |thread, _| {
2479 assert_eq!(
2480 thread.to_markdown(),
2481 indoc! {"
2482 ## User
2483
2484 Hello
2485 "}
2486 );
2487 assert_eq!(thread.latest_token_usage(), None);
2488 });
2489
2490 fake_model.send_last_completion_stream_text_chunk("Hey!");
2491 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2492 language_model::TokenUsage {
2493 input_tokens: 32_000,
2494 output_tokens: 16_000,
2495 cache_creation_input_tokens: 0,
2496 cache_read_input_tokens: 0,
2497 },
2498 ));
2499 cx.run_until_parked();
2500 thread.read_with(cx, |thread, _| {
2501 assert_eq!(
2502 thread.to_markdown(),
2503 indoc! {"
2504 ## User
2505
2506 Hello
2507
2508 ## Assistant
2509
2510 Hey!
2511 "}
2512 );
2513 assert_eq!(
2514 thread.latest_token_usage(),
2515 Some(acp_thread::TokenUsage {
2516 used_tokens: 32_000 + 16_000,
2517 max_tokens: 1_000_000,
2518 input_tokens: 32_000,
2519 output_tokens: 16_000,
2520 })
2521 );
2522 });
2523
2524 thread
2525 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2526 .unwrap();
2527 cx.run_until_parked();
2528 thread.read_with(cx, |thread, _| {
2529 assert_eq!(thread.to_markdown(), "");
2530 assert_eq!(thread.latest_token_usage(), None);
2531 });
2532
2533 // Ensure we can still send a new message after truncation.
2534 thread
2535 .update(cx, |thread, cx| {
2536 thread.send(UserMessageId::new(), ["Hi"], cx)
2537 })
2538 .unwrap();
2539 thread.update(cx, |thread, _cx| {
2540 assert_eq!(
2541 thread.to_markdown(),
2542 indoc! {"
2543 ## User
2544
2545 Hi
2546 "}
2547 );
2548 });
2549 cx.run_until_parked();
2550 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2551 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2552 language_model::TokenUsage {
2553 input_tokens: 40_000,
2554 output_tokens: 20_000,
2555 cache_creation_input_tokens: 0,
2556 cache_read_input_tokens: 0,
2557 },
2558 ));
2559 cx.run_until_parked();
2560 thread.read_with(cx, |thread, _| {
2561 assert_eq!(
2562 thread.to_markdown(),
2563 indoc! {"
2564 ## User
2565
2566 Hi
2567
2568 ## Assistant
2569
2570 Ahoy!
2571 "}
2572 );
2573
2574 assert_eq!(
2575 thread.latest_token_usage(),
2576 Some(acp_thread::TokenUsage {
2577 used_tokens: 40_000 + 20_000,
2578 max_tokens: 1_000_000,
2579 input_tokens: 40_000,
2580 output_tokens: 20_000,
2581 })
2582 );
2583 });
2584}
2585
2586#[gpui::test]
2587async fn test_truncate_second_message(cx: &mut TestAppContext) {
2588 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2589 let fake_model = model.as_fake();
2590
2591 thread
2592 .update(cx, |thread, cx| {
2593 thread.send(UserMessageId::new(), ["Message 1"], cx)
2594 })
2595 .unwrap();
2596 cx.run_until_parked();
2597 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2598 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2599 language_model::TokenUsage {
2600 input_tokens: 32_000,
2601 output_tokens: 16_000,
2602 cache_creation_input_tokens: 0,
2603 cache_read_input_tokens: 0,
2604 },
2605 ));
2606 fake_model.end_last_completion_stream();
2607 cx.run_until_parked();
2608
2609 let assert_first_message_state = |cx: &mut TestAppContext| {
2610 thread.clone().read_with(cx, |thread, _| {
2611 assert_eq!(
2612 thread.to_markdown(),
2613 indoc! {"
2614 ## User
2615
2616 Message 1
2617
2618 ## Assistant
2619
2620 Message 1 response
2621 "}
2622 );
2623
2624 assert_eq!(
2625 thread.latest_token_usage(),
2626 Some(acp_thread::TokenUsage {
2627 used_tokens: 32_000 + 16_000,
2628 max_tokens: 1_000_000,
2629 input_tokens: 32_000,
2630 output_tokens: 16_000,
2631 })
2632 );
2633 });
2634 };
2635
2636 assert_first_message_state(cx);
2637
2638 let second_message_id = UserMessageId::new();
2639 thread
2640 .update(cx, |thread, cx| {
2641 thread.send(second_message_id.clone(), ["Message 2"], cx)
2642 })
2643 .unwrap();
2644 cx.run_until_parked();
2645
2646 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2647 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2648 language_model::TokenUsage {
2649 input_tokens: 40_000,
2650 output_tokens: 20_000,
2651 cache_creation_input_tokens: 0,
2652 cache_read_input_tokens: 0,
2653 },
2654 ));
2655 fake_model.end_last_completion_stream();
2656 cx.run_until_parked();
2657
2658 thread.read_with(cx, |thread, _| {
2659 assert_eq!(
2660 thread.to_markdown(),
2661 indoc! {"
2662 ## User
2663
2664 Message 1
2665
2666 ## Assistant
2667
2668 Message 1 response
2669
2670 ## User
2671
2672 Message 2
2673
2674 ## Assistant
2675
2676 Message 2 response
2677 "}
2678 );
2679
2680 assert_eq!(
2681 thread.latest_token_usage(),
2682 Some(acp_thread::TokenUsage {
2683 used_tokens: 40_000 + 20_000,
2684 max_tokens: 1_000_000,
2685 input_tokens: 40_000,
2686 output_tokens: 20_000,
2687 })
2688 );
2689 });
2690
2691 thread
2692 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2693 .unwrap();
2694 cx.run_until_parked();
2695
2696 assert_first_message_state(cx);
2697}
2698
2699#[gpui::test]
2700async fn test_title_generation(cx: &mut TestAppContext) {
2701 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2702 let fake_model = model.as_fake();
2703
2704 let summary_model = Arc::new(FakeLanguageModel::default());
2705 thread.update(cx, |thread, cx| {
2706 thread.set_summarization_model(Some(summary_model.clone()), cx)
2707 });
2708
2709 let send = thread
2710 .update(cx, |thread, cx| {
2711 thread.send(UserMessageId::new(), ["Hello"], cx)
2712 })
2713 .unwrap();
2714 cx.run_until_parked();
2715
2716 fake_model.send_last_completion_stream_text_chunk("Hey!");
2717 fake_model.end_last_completion_stream();
2718 cx.run_until_parked();
2719 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2720
2721 // Ensure the summary model has been invoked to generate a title.
2722 summary_model.send_last_completion_stream_text_chunk("Hello ");
2723 summary_model.send_last_completion_stream_text_chunk("world\nG");
2724 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2725 summary_model.end_last_completion_stream();
2726 send.collect::<Vec<_>>().await;
2727 cx.run_until_parked();
2728 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2729
2730 // Send another message, ensuring no title is generated this time.
2731 let send = thread
2732 .update(cx, |thread, cx| {
2733 thread.send(UserMessageId::new(), ["Hello again"], cx)
2734 })
2735 .unwrap();
2736 cx.run_until_parked();
2737 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2738 fake_model.end_last_completion_stream();
2739 cx.run_until_parked();
2740 assert_eq!(summary_model.pending_completions(), Vec::new());
2741 send.collect::<Vec<_>>().await;
2742 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2743}
2744
2745#[gpui::test]
2746async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2747 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2748 let fake_model = model.as_fake();
2749
2750 let _events = thread
2751 .update(cx, |thread, cx| {
2752 thread.add_tool(ToolRequiringPermission, None);
2753 thread.add_tool(EchoTool, None);
2754 thread.send(UserMessageId::new(), ["Hey!"], cx)
2755 })
2756 .unwrap();
2757 cx.run_until_parked();
2758
2759 let permission_tool_use = LanguageModelToolUse {
2760 id: "tool_id_1".into(),
2761 name: ToolRequiringPermission::NAME.into(),
2762 raw_input: "{}".into(),
2763 input: json!({}),
2764 is_input_complete: true,
2765 thought_signature: None,
2766 };
2767 let echo_tool_use = LanguageModelToolUse {
2768 id: "tool_id_2".into(),
2769 name: EchoTool::NAME.into(),
2770 raw_input: json!({"text": "test"}).to_string(),
2771 input: json!({"text": "test"}),
2772 is_input_complete: true,
2773 thought_signature: None,
2774 };
2775 fake_model.send_last_completion_stream_text_chunk("Hi!");
2776 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2777 permission_tool_use,
2778 ));
2779 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2780 echo_tool_use.clone(),
2781 ));
2782 fake_model.end_last_completion_stream();
2783 cx.run_until_parked();
2784
2785 // Ensure pending tools are skipped when building a request.
2786 let request = thread
2787 .read_with(cx, |thread, cx| {
2788 thread.build_completion_request(CompletionIntent::EditFile, cx)
2789 })
2790 .unwrap();
2791 assert_eq!(
2792 request.messages[1..],
2793 vec![
2794 LanguageModelRequestMessage {
2795 role: Role::User,
2796 content: vec!["Hey!".into()],
2797 cache: true,
2798 reasoning_details: None,
2799 },
2800 LanguageModelRequestMessage {
2801 role: Role::Assistant,
2802 content: vec![
2803 MessageContent::Text("Hi!".into()),
2804 MessageContent::ToolUse(echo_tool_use.clone())
2805 ],
2806 cache: false,
2807 reasoning_details: None,
2808 },
2809 LanguageModelRequestMessage {
2810 role: Role::User,
2811 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2812 tool_use_id: echo_tool_use.id.clone(),
2813 tool_name: echo_tool_use.name,
2814 is_error: false,
2815 content: "test".into(),
2816 output: Some("test".into())
2817 })],
2818 cache: false,
2819 reasoning_details: None,
2820 },
2821 ],
2822 );
2823}
2824
2825#[gpui::test]
2826async fn test_agent_connection(cx: &mut TestAppContext) {
2827 cx.update(settings::init);
2828 let templates = Templates::new();
2829
2830 // Initialize language model system with test provider
2831 cx.update(|cx| {
2832 gpui_tokio::init(cx);
2833
2834 let http_client = FakeHttpClient::with_404_response();
2835 let clock = Arc::new(clock::FakeSystemClock::new());
2836 let client = Client::new(clock, http_client, cx);
2837 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2838 language_model::init(client.clone(), cx);
2839 language_models::init(user_store, client.clone(), cx);
2840 LanguageModelRegistry::test(cx);
2841 });
2842 cx.executor().forbid_parking();
2843
2844 // Create a project for new_thread
2845 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2846 fake_fs.insert_tree(path!("/test"), json!({})).await;
2847 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2848 let cwd = Path::new("/test");
2849 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2850
2851 // Create agent and connection
2852 let agent = NativeAgent::new(
2853 project.clone(),
2854 thread_store,
2855 templates.clone(),
2856 None,
2857 fake_fs.clone(),
2858 &mut cx.to_async(),
2859 )
2860 .await
2861 .unwrap();
2862 let connection = NativeAgentConnection(agent.clone());
2863
2864 // Create a thread using new_thread
2865 let connection_rc = Rc::new(connection.clone());
2866 let acp_thread = cx
2867 .update(|cx| connection_rc.new_session(project, cwd, cx))
2868 .await
2869 .expect("new_thread should succeed");
2870
2871 // Get the session_id from the AcpThread
2872 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2873
2874 // Test model_selector returns Some
2875 let selector_opt = connection.model_selector(&session_id);
2876 assert!(
2877 selector_opt.is_some(),
2878 "agent should always support ModelSelector"
2879 );
2880 let selector = selector_opt.unwrap();
2881
2882 // Test list_models
2883 let listed_models = cx
2884 .update(|cx| selector.list_models(cx))
2885 .await
2886 .expect("list_models should succeed");
2887 let AgentModelList::Grouped(listed_models) = listed_models else {
2888 panic!("Unexpected model list type");
2889 };
2890 assert!(!listed_models.is_empty(), "should have at least one model");
2891 assert_eq!(
2892 listed_models[&AgentModelGroupName("Fake".into())][0]
2893 .id
2894 .0
2895 .as_ref(),
2896 "fake/fake"
2897 );
2898
2899 // Test selected_model returns the default
2900 let model = cx
2901 .update(|cx| selector.selected_model(cx))
2902 .await
2903 .expect("selected_model should succeed");
2904 let model = cx
2905 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2906 .unwrap();
2907 let model = model.as_fake();
2908 assert_eq!(model.id().0, "fake", "should return default model");
2909
2910 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2911 cx.run_until_parked();
2912 model.send_last_completion_stream_text_chunk("def");
2913 cx.run_until_parked();
2914 acp_thread.read_with(cx, |thread, cx| {
2915 assert_eq!(
2916 thread.to_markdown(cx),
2917 indoc! {"
2918 ## User
2919
2920 abc
2921
2922 ## Assistant
2923
2924 def
2925
2926 "}
2927 )
2928 });
2929
2930 // Test cancel
2931 cx.update(|cx| connection.cancel(&session_id, cx));
2932 request.await.expect("prompt should fail gracefully");
2933
2934 // Explicitly close the session and drop the ACP thread.
2935 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
2936 .await
2937 .unwrap();
2938 drop(acp_thread);
2939 let result = cx
2940 .update(|cx| {
2941 connection.prompt(
2942 Some(acp_thread::UserMessageId::new()),
2943 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2944 cx,
2945 )
2946 })
2947 .await;
2948 assert_eq!(
2949 result.as_ref().unwrap_err().to_string(),
2950 "Session not found",
2951 "unexpected result: {:?}",
2952 result
2953 );
2954}
2955
2956#[gpui::test]
2957async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2958 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2959 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool, None));
2960 let fake_model = model.as_fake();
2961
2962 let mut events = thread
2963 .update(cx, |thread, cx| {
2964 thread.send(UserMessageId::new(), ["Think"], cx)
2965 })
2966 .unwrap();
2967 cx.run_until_parked();
2968
2969 // Simulate streaming partial input.
2970 let input = json!({});
2971 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2972 LanguageModelToolUse {
2973 id: "1".into(),
2974 name: ThinkingTool::NAME.into(),
2975 raw_input: input.to_string(),
2976 input,
2977 is_input_complete: false,
2978 thought_signature: None,
2979 },
2980 ));
2981
2982 // Input streaming completed
2983 let input = json!({ "content": "Thinking hard!" });
2984 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2985 LanguageModelToolUse {
2986 id: "1".into(),
2987 name: "thinking".into(),
2988 raw_input: input.to_string(),
2989 input,
2990 is_input_complete: true,
2991 thought_signature: None,
2992 },
2993 ));
2994 fake_model.end_last_completion_stream();
2995 cx.run_until_parked();
2996
2997 let tool_call = expect_tool_call(&mut events).await;
2998 assert_eq!(
2999 tool_call,
3000 acp::ToolCall::new("1", "Thinking")
3001 .kind(acp::ToolKind::Think)
3002 .raw_input(json!({}))
3003 .meta(acp::Meta::from_iter([(
3004 "tool_name".into(),
3005 "thinking".into()
3006 )]))
3007 );
3008 let update = expect_tool_call_update_fields(&mut events).await;
3009 assert_eq!(
3010 update,
3011 acp::ToolCallUpdate::new(
3012 "1",
3013 acp::ToolCallUpdateFields::new()
3014 .title("Thinking")
3015 .kind(acp::ToolKind::Think)
3016 .raw_input(json!({ "content": "Thinking hard!"}))
3017 )
3018 );
3019 let update = expect_tool_call_update_fields(&mut events).await;
3020 assert_eq!(
3021 update,
3022 acp::ToolCallUpdate::new(
3023 "1",
3024 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3025 )
3026 );
3027 let update = expect_tool_call_update_fields(&mut events).await;
3028 assert_eq!(
3029 update,
3030 acp::ToolCallUpdate::new(
3031 "1",
3032 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
3033 )
3034 );
3035 let update = expect_tool_call_update_fields(&mut events).await;
3036 assert_eq!(
3037 update,
3038 acp::ToolCallUpdate::new(
3039 "1",
3040 acp::ToolCallUpdateFields::new()
3041 .status(acp::ToolCallStatus::Completed)
3042 .raw_output("Finished thinking.")
3043 )
3044 );
3045}
3046
3047#[gpui::test]
3048async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3049 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3050 let fake_model = model.as_fake();
3051
3052 let mut events = thread
3053 .update(cx, |thread, cx| {
3054 thread.send(UserMessageId::new(), ["Hello!"], cx)
3055 })
3056 .unwrap();
3057 cx.run_until_parked();
3058
3059 fake_model.send_last_completion_stream_text_chunk("Hey!");
3060 fake_model.end_last_completion_stream();
3061
3062 let mut retry_events = Vec::new();
3063 while let Some(Ok(event)) = events.next().await {
3064 match event {
3065 ThreadEvent::Retry(retry_status) => {
3066 retry_events.push(retry_status);
3067 }
3068 ThreadEvent::Stop(..) => break,
3069 _ => {}
3070 }
3071 }
3072
3073 assert_eq!(retry_events.len(), 0);
3074 thread.read_with(cx, |thread, _cx| {
3075 assert_eq!(
3076 thread.to_markdown(),
3077 indoc! {"
3078 ## User
3079
3080 Hello!
3081
3082 ## Assistant
3083
3084 Hey!
3085 "}
3086 )
3087 });
3088}
3089
3090#[gpui::test]
3091async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3092 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3093 let fake_model = model.as_fake();
3094
3095 let mut events = thread
3096 .update(cx, |thread, cx| {
3097 thread.send(UserMessageId::new(), ["Hello!"], cx)
3098 })
3099 .unwrap();
3100 cx.run_until_parked();
3101
3102 fake_model.send_last_completion_stream_text_chunk("Hey,");
3103 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3104 provider: LanguageModelProviderName::new("Anthropic"),
3105 retry_after: Some(Duration::from_secs(3)),
3106 });
3107 fake_model.end_last_completion_stream();
3108
3109 cx.executor().advance_clock(Duration::from_secs(3));
3110 cx.run_until_parked();
3111
3112 fake_model.send_last_completion_stream_text_chunk("there!");
3113 fake_model.end_last_completion_stream();
3114 cx.run_until_parked();
3115
3116 let mut retry_events = Vec::new();
3117 while let Some(Ok(event)) = events.next().await {
3118 match event {
3119 ThreadEvent::Retry(retry_status) => {
3120 retry_events.push(retry_status);
3121 }
3122 ThreadEvent::Stop(..) => break,
3123 _ => {}
3124 }
3125 }
3126
3127 assert_eq!(retry_events.len(), 1);
3128 assert!(matches!(
3129 retry_events[0],
3130 acp_thread::RetryStatus { attempt: 1, .. }
3131 ));
3132 thread.read_with(cx, |thread, _cx| {
3133 assert_eq!(
3134 thread.to_markdown(),
3135 indoc! {"
3136 ## User
3137
3138 Hello!
3139
3140 ## Assistant
3141
3142 Hey,
3143
3144 [resume]
3145
3146 ## Assistant
3147
3148 there!
3149 "}
3150 )
3151 });
3152}
3153
3154#[gpui::test]
3155async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3156 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3157 let fake_model = model.as_fake();
3158
3159 let events = thread
3160 .update(cx, |thread, cx| {
3161 thread.add_tool(EchoTool, None);
3162 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3163 })
3164 .unwrap();
3165 cx.run_until_parked();
3166
3167 let tool_use_1 = LanguageModelToolUse {
3168 id: "tool_1".into(),
3169 name: EchoTool::NAME.into(),
3170 raw_input: json!({"text": "test"}).to_string(),
3171 input: json!({"text": "test"}),
3172 is_input_complete: true,
3173 thought_signature: None,
3174 };
3175 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3176 tool_use_1.clone(),
3177 ));
3178 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3179 provider: LanguageModelProviderName::new("Anthropic"),
3180 retry_after: Some(Duration::from_secs(3)),
3181 });
3182 fake_model.end_last_completion_stream();
3183
3184 cx.executor().advance_clock(Duration::from_secs(3));
3185 let completion = fake_model.pending_completions().pop().unwrap();
3186 assert_eq!(
3187 completion.messages[1..],
3188 vec![
3189 LanguageModelRequestMessage {
3190 role: Role::User,
3191 content: vec!["Call the echo tool!".into()],
3192 cache: false,
3193 reasoning_details: None,
3194 },
3195 LanguageModelRequestMessage {
3196 role: Role::Assistant,
3197 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3198 cache: false,
3199 reasoning_details: None,
3200 },
3201 LanguageModelRequestMessage {
3202 role: Role::User,
3203 content: vec![language_model::MessageContent::ToolResult(
3204 LanguageModelToolResult {
3205 tool_use_id: tool_use_1.id.clone(),
3206 tool_name: tool_use_1.name.clone(),
3207 is_error: false,
3208 content: "test".into(),
3209 output: Some("test".into())
3210 }
3211 )],
3212 cache: true,
3213 reasoning_details: None,
3214 },
3215 ]
3216 );
3217
3218 fake_model.send_last_completion_stream_text_chunk("Done");
3219 fake_model.end_last_completion_stream();
3220 cx.run_until_parked();
3221 events.collect::<Vec<_>>().await;
3222 thread.read_with(cx, |thread, _cx| {
3223 assert_eq!(
3224 thread.last_message(),
3225 Some(Message::Agent(AgentMessage {
3226 content: vec![AgentMessageContent::Text("Done".into())],
3227 tool_results: IndexMap::default(),
3228 reasoning_details: None,
3229 }))
3230 );
3231 })
3232}
3233
3234#[gpui::test]
3235async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3236 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3237 let fake_model = model.as_fake();
3238
3239 let mut events = thread
3240 .update(cx, |thread, cx| {
3241 thread.send(UserMessageId::new(), ["Hello!"], cx)
3242 })
3243 .unwrap();
3244 cx.run_until_parked();
3245
3246 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3247 fake_model.send_last_completion_stream_error(
3248 LanguageModelCompletionError::ServerOverloaded {
3249 provider: LanguageModelProviderName::new("Anthropic"),
3250 retry_after: Some(Duration::from_secs(3)),
3251 },
3252 );
3253 fake_model.end_last_completion_stream();
3254 cx.executor().advance_clock(Duration::from_secs(3));
3255 cx.run_until_parked();
3256 }
3257
3258 let mut errors = Vec::new();
3259 let mut retry_events = Vec::new();
3260 while let Some(event) = events.next().await {
3261 match event {
3262 Ok(ThreadEvent::Retry(retry_status)) => {
3263 retry_events.push(retry_status);
3264 }
3265 Ok(ThreadEvent::Stop(..)) => break,
3266 Err(error) => errors.push(error),
3267 _ => {}
3268 }
3269 }
3270
3271 assert_eq!(
3272 retry_events.len(),
3273 crate::thread::MAX_RETRY_ATTEMPTS as usize
3274 );
3275 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3276 assert_eq!(retry_events[i].attempt, i + 1);
3277 }
3278 assert_eq!(errors.len(), 1);
3279 let error = errors[0]
3280 .downcast_ref::<LanguageModelCompletionError>()
3281 .unwrap();
3282 assert!(matches!(
3283 error,
3284 LanguageModelCompletionError::ServerOverloaded { .. }
3285 ));
3286}
3287
3288/// Filters out the stop events for asserting against in tests
3289fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3290 result_events
3291 .into_iter()
3292 .filter_map(|event| match event.unwrap() {
3293 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3294 _ => None,
3295 })
3296 .collect()
3297}
3298
/// Fixture returned by `setup`: the thread under test together with the
/// objects it was constructed from.
struct ThreadTest {
    /// The language model handed to the thread.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context the thread was constructed with.
    project_context: Entity<ProjectContext>,
    /// The context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem that holds the settings file and the project root.
    fs: Arc<FakeFs>,
}
3306
/// Selects which language model `setup` wires into the thread under test.
enum TestModel {
    /// A registered model looked up by the id `claude-sonnet-4-latest`;
    /// `setup` installs a production client and authenticates the provider.
    Sonnet4,
    /// A `FakeLanguageModel` whose completion streams are driven by the test.
    Fake,
}
3311
3312impl TestModel {
3313 fn id(&self) -> LanguageModelId {
3314 match self {
3315 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3316 TestModel::Fake => unreachable!(),
3317 }
3318 }
3319}
3320
/// Builds the shared fixture for thread tests.
///
/// Writes a settings file that defines and selects a `test-profile` enabling
/// the test tools, initializes settings (and, for `TestModel::Sonnet4`, a
/// production client and the language model providers), creates a project
/// rooted at `/test`, resolves the requested model and constructs a `Thread`
/// around it.
///
/// # Panics
///
/// Panics if the settings directory cannot be created, or — for a non-fake
/// model — if the model is not registered or authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably needed so real I/O (e.g. the Sonnet4 path) may block — confirm.
    cx.executor().allow_parking();

    // Seed the fake filesystem with a settings file that enables every test tool
    // under a dedicated default profile.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            ThinkingTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider setup.
            TestModel::Fake => {}
            // A real model needs an HTTP client, a production `Client` and the
            // language model providers to be initialized.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Feed the settings file written above into the global `SettingsStore`.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    // Empty project rooted at `/test`.
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: either a ready fake, or a registered model that is
    // returned once its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    // Assemble the thread from the project and its context server store.
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3424
/// Constructor hook (registered through `ctor`) that initializes `env_logger`,
/// but only when the `RUST_LOG` environment variable is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3432
3433fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3434 let fs = fs.clone();
3435 cx.spawn({
3436 async move |cx| {
3437 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3438 cx.background_executor(),
3439 fs,
3440 paths::settings_file().clone(),
3441 );
3442 let _watcher_task = watcher_task;
3443
3444 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3445 cx.update(|cx| {
3446 SettingsStore::update_global(cx, |settings, cx| {
3447 settings.set_user_settings(&new_settings_content, cx)
3448 })
3449 })
3450 .ok();
3451 }
3452 }
3453 })
3454 .detach();
3455}
3456
3457fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3458 completion
3459 .tools
3460 .iter()
3461 .map(|tool| tool.name.clone())
3462 .collect()
3463}
3464
/// Registers and starts a fake context server called `name` that advertises
/// `tools`, and returns a receiver through which the test observes tool calls.
///
/// Each item on the returned channel is the parameters of a `CallTool`
/// request paired with a oneshot sender; the request stays pending until the
/// test sends a response through that sender.
///
/// # Panics
///
/// The `CallTool` handler panics if the returned receiver has been dropped, or
/// if the response sender is dropped without sending a response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add an enabled stdio server entry for `name` to the global project settings.
    // The command is a placeholder; the server is started below with a fake transport.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Initialize: report the latest protocol version and tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // ListTools: always answer with the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // CallTool: forward the call to the test and wait for its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    // Start the server on the fake transport and let startup work settle.
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3544
3545#[gpui::test]
3546async fn test_tokens_before_message(cx: &mut TestAppContext) {
3547 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3548 let fake_model = model.as_fake();
3549
3550 // First message
3551 let message_1_id = UserMessageId::new();
3552 thread
3553 .update(cx, |thread, cx| {
3554 thread.send(message_1_id.clone(), ["First message"], cx)
3555 })
3556 .unwrap();
3557 cx.run_until_parked();
3558
3559 // Before any response, tokens_before_message should return None for first message
3560 thread.read_with(cx, |thread, _| {
3561 assert_eq!(
3562 thread.tokens_before_message(&message_1_id),
3563 None,
3564 "First message should have no tokens before it"
3565 );
3566 });
3567
3568 // Complete first message with usage
3569 fake_model.send_last_completion_stream_text_chunk("Response 1");
3570 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3571 language_model::TokenUsage {
3572 input_tokens: 100,
3573 output_tokens: 50,
3574 cache_creation_input_tokens: 0,
3575 cache_read_input_tokens: 0,
3576 },
3577 ));
3578 fake_model.end_last_completion_stream();
3579 cx.run_until_parked();
3580
3581 // First message still has no tokens before it
3582 thread.read_with(cx, |thread, _| {
3583 assert_eq!(
3584 thread.tokens_before_message(&message_1_id),
3585 None,
3586 "First message should still have no tokens before it after response"
3587 );
3588 });
3589
3590 // Second message
3591 let message_2_id = UserMessageId::new();
3592 thread
3593 .update(cx, |thread, cx| {
3594 thread.send(message_2_id.clone(), ["Second message"], cx)
3595 })
3596 .unwrap();
3597 cx.run_until_parked();
3598
3599 // Second message should have first message's input tokens before it
3600 thread.read_with(cx, |thread, _| {
3601 assert_eq!(
3602 thread.tokens_before_message(&message_2_id),
3603 Some(100),
3604 "Second message should have 100 tokens before it (from first request)"
3605 );
3606 });
3607
3608 // Complete second message
3609 fake_model.send_last_completion_stream_text_chunk("Response 2");
3610 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3611 language_model::TokenUsage {
3612 input_tokens: 250, // Total for this request (includes previous context)
3613 output_tokens: 75,
3614 cache_creation_input_tokens: 0,
3615 cache_read_input_tokens: 0,
3616 },
3617 ));
3618 fake_model.end_last_completion_stream();
3619 cx.run_until_parked();
3620
3621 // Third message
3622 let message_3_id = UserMessageId::new();
3623 thread
3624 .update(cx, |thread, cx| {
3625 thread.send(message_3_id.clone(), ["Third message"], cx)
3626 })
3627 .unwrap();
3628 cx.run_until_parked();
3629
3630 // Third message should have second message's input tokens (250) before it
3631 thread.read_with(cx, |thread, _| {
3632 assert_eq!(
3633 thread.tokens_before_message(&message_3_id),
3634 Some(250),
3635 "Third message should have 250 tokens before it (from second request)"
3636 );
3637 // Second message should still have 100
3638 assert_eq!(
3639 thread.tokens_before_message(&message_2_id),
3640 Some(100),
3641 "Second message should still have 100 tokens before it"
3642 );
3643 // First message still has none
3644 assert_eq!(
3645 thread.tokens_before_message(&message_1_id),
3646 None,
3647 "First message should still have no tokens before it"
3648 );
3649 });
3650}
3651
3652#[gpui::test]
3653async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3654 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3655 let fake_model = model.as_fake();
3656
3657 // Set up three messages with responses
3658 let message_1_id = UserMessageId::new();
3659 thread
3660 .update(cx, |thread, cx| {
3661 thread.send(message_1_id.clone(), ["Message 1"], cx)
3662 })
3663 .unwrap();
3664 cx.run_until_parked();
3665 fake_model.send_last_completion_stream_text_chunk("Response 1");
3666 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3667 language_model::TokenUsage {
3668 input_tokens: 100,
3669 output_tokens: 50,
3670 cache_creation_input_tokens: 0,
3671 cache_read_input_tokens: 0,
3672 },
3673 ));
3674 fake_model.end_last_completion_stream();
3675 cx.run_until_parked();
3676
3677 let message_2_id = UserMessageId::new();
3678 thread
3679 .update(cx, |thread, cx| {
3680 thread.send(message_2_id.clone(), ["Message 2"], cx)
3681 })
3682 .unwrap();
3683 cx.run_until_parked();
3684 fake_model.send_last_completion_stream_text_chunk("Response 2");
3685 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3686 language_model::TokenUsage {
3687 input_tokens: 250,
3688 output_tokens: 75,
3689 cache_creation_input_tokens: 0,
3690 cache_read_input_tokens: 0,
3691 },
3692 ));
3693 fake_model.end_last_completion_stream();
3694 cx.run_until_parked();
3695
3696 // Verify initial state
3697 thread.read_with(cx, |thread, _| {
3698 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3699 });
3700
3701 // Truncate at message 2 (removes message 2 and everything after)
3702 thread
3703 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3704 .unwrap();
3705 cx.run_until_parked();
3706
3707 // After truncation, message_2_id no longer exists, so lookup should return None
3708 thread.read_with(cx, |thread, _| {
3709 assert_eq!(
3710 thread.tokens_before_message(&message_2_id),
3711 None,
3712 "After truncation, message 2 no longer exists"
3713 );
3714 // Message 1 still exists but has no tokens before it
3715 assert_eq!(
3716 thread.tokens_before_message(&message_1_id),
3717 None,
3718 "First message still has no tokens before it"
3719 );
3720 });
3721}
3722
/// Exercises the terminal tool against four permission configurations:
///
/// 1. an `always_deny` pattern blocks a matching command;
/// 2. an `always_allow` pattern runs a matching command without confirmation,
///    even though the tool default is `Deny`;
/// 3. an `always_confirm` pattern still requests authorization when both the
///    global and the tool default are `Allow`;
/// 4. a tool-specific default of `Deny` wins over a global default of `Allow`.
///
/// Every scenario overrides the global `AgentSettings`, so the scenarios are
/// order-dependent: each one starts from the settings left by the previous one
/// and replaces the terminal tool's rules.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
        }));

        // Tool default is Confirm, but anything matching `rm\s+-rf` is always denied.
        // NOTE(review): the second `CompiledRegex::new` argument is presumably a
        // case-sensitivity flag — confirm.
        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Confirm),
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        // `_rx` is kept bound so the receiving end of the event stream stays alive.
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The run must fail, and the error text must say the command was blocked.
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        let err_msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default: Deny)
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        // Tool default is Deny, but commands starting with `echo ` are always allowed.
        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The first event must be a field update carrying terminal content,
        // i.e. the command started without an authorization request.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: global default: allow does NOT override always_confirm patterns
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        // Both the global and the tool default are Allow; only `sudo` requires confirmation.
        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Allow),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        // The task is held (not awaited) while the authorization request is observed.
        let _task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With global default: allow, confirm patterns are still respected
        // The expect_authorization() call will panic if no authorization is requested,
        // which validates that the confirm pattern still triggers confirmation
        let _auth = rx.expect_authorization().await;

        // The authorization is never answered; drop the pending run explicitly.
        drop(_task);
    }

    // Test 4: tool-specific default: deny is respected even with global default: allow
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        // Global default Allow, tool default Deny, and no patterns at all.
        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // tool-specific default: deny is respected even with global default: allow
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by tool-specific deny default"
        );
        let err_msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            err_msg.contains("disabled"),
            "error should mention the tool is disabled, got: {err_msg}"
        );
    }
}
3940
3941#[gpui::test]
3942async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3943 init_test(cx);
3944
3945 cx.update(|cx| {
3946 cx.update_flags(true, vec!["subagents".to_string()]);
3947 });
3948
3949 let fs = FakeFs::new(cx.executor());
3950 fs.insert_tree(path!("/test"), json!({})).await;
3951 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3952 let project_context = cx.new(|_cx| ProjectContext::default());
3953 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3954 let context_server_registry =
3955 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3956 let model = Arc::new(FakeLanguageModel::default());
3957
3958 let environment = Rc::new(cx.update(|cx| {
3959 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3960 }));
3961
3962 let thread = cx.new(|cx| {
3963 let mut thread = Thread::new(
3964 project.clone(),
3965 project_context,
3966 context_server_registry,
3967 Templates::new(),
3968 Some(model),
3969 cx,
3970 );
3971 thread.add_default_tools(None, environment, cx);
3972 thread
3973 });
3974
3975 thread.read_with(cx, |thread, _| {
3976 assert!(
3977 thread.has_registered_tool(SubagentTool::NAME),
3978 "subagent tool should be present when feature flag is enabled"
3979 );
3980 });
3981}
3982
3983#[gpui::test]
3984async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
3985 init_test(cx);
3986
3987 cx.update(|cx| {
3988 cx.update_flags(true, vec!["subagents".to_string()]);
3989 });
3990
3991 let fs = FakeFs::new(cx.executor());
3992 fs.insert_tree(path!("/test"), json!({})).await;
3993 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3994 let project_context = cx.new(|_cx| ProjectContext::default());
3995 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3996 let context_server_registry =
3997 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3998 let model = Arc::new(FakeLanguageModel::default());
3999
4000 let parent_thread = cx.new(|cx| {
4001 Thread::new(
4002 project.clone(),
4003 project_context,
4004 context_server_registry,
4005 Templates::new(),
4006 Some(model.clone()),
4007 cx,
4008 )
4009 });
4010
4011 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4012 subagent_thread.read_with(cx, |subagent_thread, cx| {
4013 assert!(subagent_thread.is_subagent());
4014 assert_eq!(subagent_thread.depth(), 1);
4015 assert_eq!(
4016 subagent_thread.model().map(|model| model.id()),
4017 Some(model.id())
4018 );
4019 assert_eq!(
4020 subagent_thread.parent_thread_id(),
4021 Some(parent_thread.read(cx).id().clone())
4022 );
4023 });
4024}
4025
4026#[gpui::test]
4027async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4028 init_test(cx);
4029
4030 cx.update(|cx| {
4031 cx.update_flags(true, vec!["subagents".to_string()]);
4032 });
4033
4034 let fs = FakeFs::new(cx.executor());
4035 fs.insert_tree(path!("/test"), json!({})).await;
4036 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4037 let project_context = cx.new(|_cx| ProjectContext::default());
4038 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4039 let context_server_registry =
4040 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4041 let model = Arc::new(FakeLanguageModel::default());
4042 let environment = Rc::new(cx.update(|cx| {
4043 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4044 }));
4045
4046 let deep_parent_thread = cx.new(|cx| {
4047 let mut thread = Thread::new(
4048 project.clone(),
4049 project_context,
4050 context_server_registry,
4051 Templates::new(),
4052 Some(model.clone()),
4053 cx,
4054 );
4055 thread.set_subagent_context(SubagentContext {
4056 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4057 depth: MAX_SUBAGENT_DEPTH - 1,
4058 });
4059 thread
4060 });
4061 let deep_subagent_thread = cx.new(|cx| {
4062 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4063 thread.add_default_tools(None, environment, cx);
4064 thread
4065 });
4066
4067 deep_subagent_thread.read_with(cx, |thread, _| {
4068 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4069 assert!(
4070 !thread.has_registered_tool(SubagentTool::NAME),
4071 "subagent tool should not be present at max depth"
4072 );
4073 });
4074}
4075
4076#[gpui::test]
4077async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4078 init_test(cx);
4079
4080 cx.update(|cx| {
4081 cx.update_flags(true, vec!["subagents".to_string()]);
4082 });
4083
4084 let fs = FakeFs::new(cx.executor());
4085 fs.insert_tree(path!("/test"), json!({})).await;
4086 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4087 let project_context = cx.new(|_cx| ProjectContext::default());
4088 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4089 let context_server_registry =
4090 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4091 let model = Arc::new(FakeLanguageModel::default());
4092
4093 let parent = cx.new(|cx| {
4094 Thread::new(
4095 project.clone(),
4096 project_context.clone(),
4097 context_server_registry.clone(),
4098 Templates::new(),
4099 Some(model.clone()),
4100 cx,
4101 )
4102 });
4103
4104 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4105
4106 parent.update(cx, |thread, _cx| {
4107 thread.register_running_subagent(subagent.downgrade());
4108 });
4109
4110 subagent
4111 .update(cx, |thread, cx| {
4112 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4113 })
4114 .unwrap();
4115 cx.run_until_parked();
4116
4117 subagent.read_with(cx, |thread, _| {
4118 assert!(!thread.is_turn_complete(), "subagent should be running");
4119 });
4120
4121 parent.update(cx, |thread, cx| {
4122 thread.cancel(cx).detach();
4123 });
4124
4125 subagent.read_with(cx, |thread, _| {
4126 assert!(
4127 thread.is_turn_complete(),
4128 "subagent should be cancelled when parent cancels"
4129 );
4130 });
4131}
4132
4133#[gpui::test]
4134async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4135 // This test verifies that the subagent tool properly handles user cancellation
4136 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4137 init_test(cx);
4138 always_allow_tools(cx);
4139
4140 cx.update(|cx| {
4141 cx.update_flags(true, vec!["subagents".to_string()]);
4142 });
4143
4144 let fs = FakeFs::new(cx.executor());
4145 fs.insert_tree(path!("/test"), json!({})).await;
4146 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4147 let project_context = cx.new(|_cx| ProjectContext::default());
4148 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4149 let context_server_registry =
4150 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4151 let model = Arc::new(FakeLanguageModel::default());
4152 let environment = Rc::new(cx.update(|cx| {
4153 FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
4154 }));
4155
4156 let parent = cx.new(|cx| {
4157 Thread::new(
4158 project.clone(),
4159 project_context.clone(),
4160 context_server_registry.clone(),
4161 Templates::new(),
4162 Some(model.clone()),
4163 cx,
4164 )
4165 });
4166
4167 #[allow(clippy::arc_with_non_send_sync)]
4168 let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));
4169
4170 let (event_stream, _rx, mut cancellation_tx) =
4171 crate::ToolCallEventStream::test_with_cancellation();
4172
4173 // Start the subagent tool
4174 let task = cx.update(|cx| {
4175 tool.run(
4176 SubagentToolInput {
4177 label: "Long running task".to_string(),
4178 task_prompt: "Do a very long task that takes forever".to_string(),
4179 summary_prompt: "Summarize".to_string(),
4180 timeout_ms: None,
4181 allowed_tools: None,
4182 },
4183 event_stream.clone(),
4184 cx,
4185 )
4186 });
4187
4188 cx.run_until_parked();
4189
4190 // Signal cancellation via the event stream
4191 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4192
4193 // The task should complete promptly with a cancellation error
4194 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4195 let result = futures::select! {
4196 result = task.fuse() => result,
4197 _ = timeout.fuse() => {
4198 panic!("subagent tool did not respond to cancellation within timeout");
4199 }
4200 };
4201
4202 // Verify we got a cancellation error
4203 let err = result.unwrap_err();
4204 assert!(
4205 err.to_string().contains("cancelled by user"),
4206 "expected cancellation error, got: {}",
4207 err
4208 );
4209}
4210
4211#[gpui::test]
4212async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4213 init_test(cx);
4214 always_allow_tools(cx);
4215
4216 cx.update(|cx| {
4217 cx.update_flags(true, vec!["subagents".to_string()]);
4218 });
4219
4220 let fs = FakeFs::new(cx.executor());
4221 fs.insert_tree(path!("/test"), json!({})).await;
4222 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4223 let project_context = cx.new(|_cx| ProjectContext::default());
4224 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4225 let context_server_registry =
4226 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4227 cx.update(LanguageModelRegistry::test);
4228 let model = Arc::new(FakeLanguageModel::default());
4229 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4230 let native_agent = NativeAgent::new(
4231 project.clone(),
4232 thread_store,
4233 Templates::new(),
4234 None,
4235 fs,
4236 &mut cx.to_async(),
4237 )
4238 .await
4239 .unwrap();
4240 let parent_thread = cx.new(|cx| {
4241 Thread::new(
4242 project.clone(),
4243 project_context,
4244 context_server_registry,
4245 Templates::new(),
4246 Some(model.clone()),
4247 cx,
4248 )
4249 });
4250
4251 let mut handles = Vec::new();
4252 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4253 let handle = cx
4254 .update(|cx| {
4255 NativeThreadEnvironment::create_subagent_thread(
4256 native_agent.downgrade(),
4257 parent_thread.clone(),
4258 "some title".to_string(),
4259 "some task".to_string(),
4260 None,
4261 None,
4262 cx,
4263 )
4264 })
4265 .expect("Expected to be able to create subagent thread");
4266 handles.push(handle);
4267 }
4268
4269 let result = cx.update(|cx| {
4270 NativeThreadEnvironment::create_subagent_thread(
4271 native_agent.downgrade(),
4272 parent_thread.clone(),
4273 "some title".to_string(),
4274 "some task".to_string(),
4275 None,
4276 None,
4277 cx,
4278 )
4279 });
4280 assert!(result.is_err());
4281 assert_eq!(
4282 result.err().unwrap().to_string(),
4283 format!(
4284 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4285 MAX_PARALLEL_SUBAGENTS
4286 )
4287 );
4288}
4289
4290#[gpui::test]
4291async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4292 init_test(cx);
4293
4294 always_allow_tools(cx);
4295
4296 cx.update(|cx| {
4297 cx.update_flags(true, vec!["subagents".to_string()]);
4298 });
4299
4300 let fs = FakeFs::new(cx.executor());
4301 fs.insert_tree(path!("/test"), json!({})).await;
4302 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4303 let project_context = cx.new(|_cx| ProjectContext::default());
4304 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4305 let context_server_registry =
4306 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4307 cx.update(LanguageModelRegistry::test);
4308 let model = Arc::new(FakeLanguageModel::default());
4309 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4310 let native_agent = NativeAgent::new(
4311 project.clone(),
4312 thread_store,
4313 Templates::new(),
4314 None,
4315 fs,
4316 &mut cx.to_async(),
4317 )
4318 .await
4319 .unwrap();
4320 let parent_thread = cx.new(|cx| {
4321 Thread::new(
4322 project.clone(),
4323 project_context,
4324 context_server_registry,
4325 Templates::new(),
4326 Some(model.clone()),
4327 cx,
4328 )
4329 });
4330
4331 let subagent_handle = cx
4332 .update(|cx| {
4333 NativeThreadEnvironment::create_subagent_thread(
4334 native_agent.downgrade(),
4335 parent_thread.clone(),
4336 "some title".to_string(),
4337 "task prompt".to_string(),
4338 Some(Duration::from_millis(10)),
4339 None,
4340 cx,
4341 )
4342 })
4343 .expect("Failed to create subagent");
4344
4345 let summary_task =
4346 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4347
4348 cx.run_until_parked();
4349
4350 {
4351 let messages = model.pending_completions().last().unwrap().messages.clone();
4352 // Ensure that model received a system prompt
4353 assert_eq!(messages[0].role, Role::System);
4354 // Ensure that model received a task prompt
4355 assert_eq!(messages[1].role, Role::User);
4356 assert_eq!(
4357 messages[1].content,
4358 vec![MessageContent::Text("task prompt".to_string())]
4359 );
4360 }
4361
4362 model.send_last_completion_stream_text_chunk("Some task response...");
4363 model.end_last_completion_stream();
4364
4365 cx.run_until_parked();
4366
4367 {
4368 let messages = model.pending_completions().last().unwrap().messages.clone();
4369 assert_eq!(messages[2].role, Role::Assistant);
4370 assert_eq!(
4371 messages[2].content,
4372 vec![MessageContent::Text("Some task response...".to_string())]
4373 );
4374 // Ensure that model received a summary prompt
4375 assert_eq!(messages[3].role, Role::User);
4376 assert_eq!(
4377 messages[3].content,
4378 vec![MessageContent::Text("summary prompt".to_string())]
4379 );
4380 }
4381
4382 model.send_last_completion_stream_text_chunk("Some summary...");
4383 model.end_last_completion_stream();
4384
4385 let result = summary_task.await;
4386 assert_eq!(result.unwrap(), "Some summary...\n");
4387}
4388
4389#[gpui::test]
4390async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4391 cx: &mut TestAppContext,
4392) {
4393 init_test(cx);
4394
4395 always_allow_tools(cx);
4396
4397 cx.update(|cx| {
4398 cx.update_flags(true, vec!["subagents".to_string()]);
4399 });
4400
4401 let fs = FakeFs::new(cx.executor());
4402 fs.insert_tree(path!("/test"), json!({})).await;
4403 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4404 let project_context = cx.new(|_cx| ProjectContext::default());
4405 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4406 let context_server_registry =
4407 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4408 cx.update(LanguageModelRegistry::test);
4409 let model = Arc::new(FakeLanguageModel::default());
4410 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4411 let native_agent = NativeAgent::new(
4412 project.clone(),
4413 thread_store,
4414 Templates::new(),
4415 None,
4416 fs,
4417 &mut cx.to_async(),
4418 )
4419 .await
4420 .unwrap();
4421 let parent_thread = cx.new(|cx| {
4422 Thread::new(
4423 project.clone(),
4424 project_context,
4425 context_server_registry,
4426 Templates::new(),
4427 Some(model.clone()),
4428 cx,
4429 )
4430 });
4431
4432 let subagent_handle = cx
4433 .update(|cx| {
4434 NativeThreadEnvironment::create_subagent_thread(
4435 native_agent.downgrade(),
4436 parent_thread.clone(),
4437 "some title".to_string(),
4438 "task prompt".to_string(),
4439 Some(Duration::from_millis(100)),
4440 None,
4441 cx,
4442 )
4443 })
4444 .expect("Failed to create subagent");
4445
4446 let summary_task =
4447 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4448
4449 cx.run_until_parked();
4450
4451 {
4452 let messages = model.pending_completions().last().unwrap().messages.clone();
4453 // Ensure that model received a system prompt
4454 assert_eq!(messages[0].role, Role::System);
4455 // Ensure that model received a task prompt
4456 assert_eq!(
4457 messages[1].content,
4458 vec![MessageContent::Text("task prompt".to_string())]
4459 );
4460 }
4461
4462 // Don't complete the initial model stream — let the timeout expire instead.
4463 cx.executor().advance_clock(Duration::from_millis(200));
4464 cx.run_until_parked();
4465
4466 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4467 // instead of the summary_prompt.
4468 {
4469 let messages = model.pending_completions().last().unwrap().messages.clone();
4470 let last_user_message = messages
4471 .iter()
4472 .rev()
4473 .find(|m| m.role == Role::User)
4474 .unwrap();
4475 assert_eq!(
4476 last_user_message.content,
4477 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
4478 );
4479 }
4480
4481 model.send_last_completion_stream_text_chunk("Some context low response...");
4482 model.end_last_completion_stream();
4483
4484 let result = summary_task.await;
4485 assert_eq!(result.unwrap(), "Some context low response...\n");
4486}
4487
4488#[gpui::test]
4489async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4490 init_test(cx);
4491
4492 always_allow_tools(cx);
4493
4494 cx.update(|cx| {
4495 cx.update_flags(true, vec!["subagents".to_string()]);
4496 });
4497
4498 let fs = FakeFs::new(cx.executor());
4499 fs.insert_tree(path!("/test"), json!({})).await;
4500 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4501 let project_context = cx.new(|_cx| ProjectContext::default());
4502 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4503 let context_server_registry =
4504 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4505 cx.update(LanguageModelRegistry::test);
4506 let model = Arc::new(FakeLanguageModel::default());
4507 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4508 let native_agent = NativeAgent::new(
4509 project.clone(),
4510 thread_store,
4511 Templates::new(),
4512 None,
4513 fs,
4514 &mut cx.to_async(),
4515 )
4516 .await
4517 .unwrap();
4518 let parent_thread = cx.new(|cx| {
4519 let mut thread = Thread::new(
4520 project.clone(),
4521 project_context,
4522 context_server_registry,
4523 Templates::new(),
4524 Some(model.clone()),
4525 cx,
4526 );
4527 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4528 thread.add_tool(GrepTool::new(project.clone()), None);
4529 thread
4530 });
4531
4532 let _subagent_handle = cx
4533 .update(|cx| {
4534 NativeThreadEnvironment::create_subagent_thread(
4535 native_agent.downgrade(),
4536 parent_thread.clone(),
4537 "some title".to_string(),
4538 "task prompt".to_string(),
4539 Some(Duration::from_millis(10)),
4540 None,
4541 cx,
4542 )
4543 })
4544 .expect("Failed to create subagent");
4545
4546 cx.run_until_parked();
4547
4548 let tools = model
4549 .pending_completions()
4550 .last()
4551 .unwrap()
4552 .tools
4553 .iter()
4554 .map(|tool| tool.name.clone())
4555 .collect::<Vec<_>>();
4556 assert_eq!(tools.len(), 2);
4557 assert!(tools.contains(&"grep".to_string()));
4558 assert!(tools.contains(&"list_directory".to_string()));
4559}
4560
4561#[gpui::test]
4562async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4563 init_test(cx);
4564
4565 always_allow_tools(cx);
4566
4567 cx.update(|cx| {
4568 cx.update_flags(true, vec!["subagents".to_string()]);
4569 });
4570
4571 let fs = FakeFs::new(cx.executor());
4572 fs.insert_tree(path!("/test"), json!({})).await;
4573 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4574 let project_context = cx.new(|_cx| ProjectContext::default());
4575 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4576 let context_server_registry =
4577 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4578 cx.update(LanguageModelRegistry::test);
4579 let model = Arc::new(FakeLanguageModel::default());
4580 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4581 let native_agent = NativeAgent::new(
4582 project.clone(),
4583 thread_store,
4584 Templates::new(),
4585 None,
4586 fs,
4587 &mut cx.to_async(),
4588 )
4589 .await
4590 .unwrap();
4591 let parent_thread = cx.new(|cx| {
4592 let mut thread = Thread::new(
4593 project.clone(),
4594 project_context,
4595 context_server_registry,
4596 Templates::new(),
4597 Some(model.clone()),
4598 cx,
4599 );
4600 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4601 thread.add_tool(GrepTool::new(project.clone()), None);
4602 thread
4603 });
4604
4605 let _subagent_handle = cx
4606 .update(|cx| {
4607 NativeThreadEnvironment::create_subagent_thread(
4608 native_agent.downgrade(),
4609 parent_thread.clone(),
4610 "some title".to_string(),
4611 "task prompt".to_string(),
4612 Some(Duration::from_millis(10)),
4613 Some(vec!["grep".to_string()]),
4614 cx,
4615 )
4616 })
4617 .expect("Failed to create subagent");
4618
4619 cx.run_until_parked();
4620
4621 let tools = model
4622 .pending_completions()
4623 .last()
4624 .unwrap()
4625 .tools
4626 .iter()
4627 .map(|tool| tool.name.clone())
4628 .collect::<Vec<_>>();
4629 assert_eq!(tools, vec!["grep"]);
4630}
4631
4632#[gpui::test]
4633async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4634 init_test(cx);
4635
4636 let fs = FakeFs::new(cx.executor());
4637 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4638 .await;
4639 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4640
4641 cx.update(|cx| {
4642 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4643 settings.tool_permissions.tools.insert(
4644 EditFileTool::NAME.into(),
4645 agent_settings::ToolRules {
4646 default: Some(settings::ToolPermissionMode::Allow),
4647 always_allow: vec![],
4648 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4649 always_confirm: vec![],
4650 invalid_patterns: vec![],
4651 },
4652 );
4653 agent_settings::AgentSettings::override_global(settings, cx);
4654 });
4655
4656 let context_server_registry =
4657 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4658 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4659 let templates = crate::Templates::new();
4660 let thread = cx.new(|cx| {
4661 crate::Thread::new(
4662 project.clone(),
4663 cx.new(|_cx| prompt_store::ProjectContext::default()),
4664 context_server_registry,
4665 templates.clone(),
4666 None,
4667 cx,
4668 )
4669 });
4670
4671 #[allow(clippy::arc_with_non_send_sync)]
4672 let tool = Arc::new(crate::EditFileTool::new(
4673 project.clone(),
4674 thread.downgrade(),
4675 language_registry,
4676 templates,
4677 ));
4678 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4679
4680 let task = cx.update(|cx| {
4681 tool.run(
4682 crate::EditFileToolInput {
4683 display_description: "Edit sensitive file".to_string(),
4684 path: "root/sensitive_config.txt".into(),
4685 mode: crate::EditFileMode::Edit,
4686 },
4687 event_stream,
4688 cx,
4689 )
4690 });
4691
4692 let result = task.await;
4693 assert!(result.is_err(), "expected edit to be blocked");
4694 assert!(
4695 result.unwrap_err().to_string().contains("blocked"),
4696 "error should mention the edit was blocked"
4697 );
4698}
4699
4700#[gpui::test]
4701async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4702 init_test(cx);
4703
4704 let fs = FakeFs::new(cx.executor());
4705 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4706 .await;
4707 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4708
4709 cx.update(|cx| {
4710 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4711 settings.tool_permissions.tools.insert(
4712 DeletePathTool::NAME.into(),
4713 agent_settings::ToolRules {
4714 default: Some(settings::ToolPermissionMode::Allow),
4715 always_allow: vec![],
4716 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4717 always_confirm: vec![],
4718 invalid_patterns: vec![],
4719 },
4720 );
4721 agent_settings::AgentSettings::override_global(settings, cx);
4722 });
4723
4724 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4725
4726 #[allow(clippy::arc_with_non_send_sync)]
4727 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4728 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4729
4730 let task = cx.update(|cx| {
4731 tool.run(
4732 crate::DeletePathToolInput {
4733 path: "root/important_data.txt".to_string(),
4734 },
4735 event_stream,
4736 cx,
4737 )
4738 });
4739
4740 let result = task.await;
4741 assert!(result.is_err(), "expected deletion to be blocked");
4742 assert!(
4743 result.unwrap_err().to_string().contains("blocked"),
4744 "error should mention the deletion was blocked"
4745 );
4746}
4747
4748#[gpui::test]
4749async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4750 init_test(cx);
4751
4752 let fs = FakeFs::new(cx.executor());
4753 fs.insert_tree(
4754 "/root",
4755 json!({
4756 "safe.txt": "content",
4757 "protected": {}
4758 }),
4759 )
4760 .await;
4761 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4762
4763 cx.update(|cx| {
4764 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4765 settings.tool_permissions.tools.insert(
4766 MovePathTool::NAME.into(),
4767 agent_settings::ToolRules {
4768 default: Some(settings::ToolPermissionMode::Allow),
4769 always_allow: vec![],
4770 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4771 always_confirm: vec![],
4772 invalid_patterns: vec![],
4773 },
4774 );
4775 agent_settings::AgentSettings::override_global(settings, cx);
4776 });
4777
4778 #[allow(clippy::arc_with_non_send_sync)]
4779 let tool = Arc::new(crate::MovePathTool::new(project));
4780 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4781
4782 let task = cx.update(|cx| {
4783 tool.run(
4784 crate::MovePathToolInput {
4785 source_path: "root/safe.txt".to_string(),
4786 destination_path: "root/protected/safe.txt".to_string(),
4787 },
4788 event_stream,
4789 cx,
4790 )
4791 });
4792
4793 let result = task.await;
4794 assert!(
4795 result.is_err(),
4796 "expected move to be blocked due to destination path"
4797 );
4798 assert!(
4799 result.unwrap_err().to_string().contains("blocked"),
4800 "error should mention the move was blocked"
4801 );
4802}
4803
4804#[gpui::test]
4805async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
4806 init_test(cx);
4807
4808 let fs = FakeFs::new(cx.executor());
4809 fs.insert_tree(
4810 "/root",
4811 json!({
4812 "secret.txt": "secret content",
4813 "public": {}
4814 }),
4815 )
4816 .await;
4817 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4818
4819 cx.update(|cx| {
4820 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4821 settings.tool_permissions.tools.insert(
4822 MovePathTool::NAME.into(),
4823 agent_settings::ToolRules {
4824 default: Some(settings::ToolPermissionMode::Allow),
4825 always_allow: vec![],
4826 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
4827 always_confirm: vec![],
4828 invalid_patterns: vec![],
4829 },
4830 );
4831 agent_settings::AgentSettings::override_global(settings, cx);
4832 });
4833
4834 #[allow(clippy::arc_with_non_send_sync)]
4835 let tool = Arc::new(crate::MovePathTool::new(project));
4836 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4837
4838 let task = cx.update(|cx| {
4839 tool.run(
4840 crate::MovePathToolInput {
4841 source_path: "root/secret.txt".to_string(),
4842 destination_path: "root/public/not_secret.txt".to_string(),
4843 },
4844 event_stream,
4845 cx,
4846 )
4847 });
4848
4849 let result = task.await;
4850 assert!(
4851 result.is_err(),
4852 "expected move to be blocked due to source path"
4853 );
4854 assert!(
4855 result.unwrap_err().to_string().contains("blocked"),
4856 "error should mention the move was blocked"
4857 );
4858}
4859
4860#[gpui::test]
4861async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
4862 init_test(cx);
4863
4864 let fs = FakeFs::new(cx.executor());
4865 fs.insert_tree(
4866 "/root",
4867 json!({
4868 "confidential.txt": "confidential data",
4869 "dest": {}
4870 }),
4871 )
4872 .await;
4873 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4874
4875 cx.update(|cx| {
4876 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4877 settings.tool_permissions.tools.insert(
4878 CopyPathTool::NAME.into(),
4879 agent_settings::ToolRules {
4880 default: Some(settings::ToolPermissionMode::Allow),
4881 always_allow: vec![],
4882 always_deny: vec![
4883 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
4884 ],
4885 always_confirm: vec![],
4886 invalid_patterns: vec![],
4887 },
4888 );
4889 agent_settings::AgentSettings::override_global(settings, cx);
4890 });
4891
4892 #[allow(clippy::arc_with_non_send_sync)]
4893 let tool = Arc::new(crate::CopyPathTool::new(project));
4894 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4895
4896 let task = cx.update(|cx| {
4897 tool.run(
4898 crate::CopyPathToolInput {
4899 source_path: "root/confidential.txt".to_string(),
4900 destination_path: "root/dest/copy.txt".to_string(),
4901 },
4902 event_stream,
4903 cx,
4904 )
4905 });
4906
4907 let result = task.await;
4908 assert!(result.is_err(), "expected copy to be blocked");
4909 assert!(
4910 result.unwrap_err().to_string().contains("blocked"),
4911 "error should mention the copy was blocked"
4912 );
4913}
4914
4915#[gpui::test]
4916async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
4917 init_test(cx);
4918
4919 let fs = FakeFs::new(cx.executor());
4920 fs.insert_tree(
4921 "/root",
4922 json!({
4923 "normal.txt": "normal content",
4924 "readonly": {
4925 "config.txt": "readonly content"
4926 }
4927 }),
4928 )
4929 .await;
4930 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4931
4932 cx.update(|cx| {
4933 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4934 settings.tool_permissions.tools.insert(
4935 SaveFileTool::NAME.into(),
4936 agent_settings::ToolRules {
4937 default: Some(settings::ToolPermissionMode::Allow),
4938 always_allow: vec![],
4939 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
4940 always_confirm: vec![],
4941 invalid_patterns: vec![],
4942 },
4943 );
4944 agent_settings::AgentSettings::override_global(settings, cx);
4945 });
4946
4947 #[allow(clippy::arc_with_non_send_sync)]
4948 let tool = Arc::new(crate::SaveFileTool::new(project));
4949 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4950
4951 let task = cx.update(|cx| {
4952 tool.run(
4953 crate::SaveFileToolInput {
4954 paths: vec![
4955 std::path::PathBuf::from("root/normal.txt"),
4956 std::path::PathBuf::from("root/readonly/config.txt"),
4957 ],
4958 },
4959 event_stream,
4960 cx,
4961 )
4962 });
4963
4964 let result = task.await;
4965 assert!(
4966 result.is_err(),
4967 "expected save to be blocked due to denied path"
4968 );
4969 assert!(
4970 result.unwrap_err().to_string().contains("blocked"),
4971 "error should mention the save was blocked"
4972 );
4973}
4974
4975#[gpui::test]
4976async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
4977 init_test(cx);
4978
4979 let fs = FakeFs::new(cx.executor());
4980 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
4981 .await;
4982 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4983
4984 cx.update(|cx| {
4985 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4986 settings.tool_permissions.tools.insert(
4987 SaveFileTool::NAME.into(),
4988 agent_settings::ToolRules {
4989 default: Some(settings::ToolPermissionMode::Allow),
4990 always_allow: vec![],
4991 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
4992 always_confirm: vec![],
4993 invalid_patterns: vec![],
4994 },
4995 );
4996 agent_settings::AgentSettings::override_global(settings, cx);
4997 });
4998
4999 #[allow(clippy::arc_with_non_send_sync)]
5000 let tool = Arc::new(crate::SaveFileTool::new(project));
5001 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5002
5003 let task = cx.update(|cx| {
5004 tool.run(
5005 crate::SaveFileToolInput {
5006 paths: vec![std::path::PathBuf::from("root/config.secret")],
5007 },
5008 event_stream,
5009 cx,
5010 )
5011 });
5012
5013 let result = task.await;
5014 assert!(result.is_err(), "expected save to be blocked");
5015 assert!(
5016 result.unwrap_err().to_string().contains("blocked"),
5017 "error should mention the save was blocked"
5018 );
5019}
5020
5021#[gpui::test]
5022async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5023 init_test(cx);
5024
5025 cx.update(|cx| {
5026 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5027 settings.tool_permissions.tools.insert(
5028 WebSearchTool::NAME.into(),
5029 agent_settings::ToolRules {
5030 default: Some(settings::ToolPermissionMode::Allow),
5031 always_allow: vec![],
5032 always_deny: vec![
5033 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5034 ],
5035 always_confirm: vec![],
5036 invalid_patterns: vec![],
5037 },
5038 );
5039 agent_settings::AgentSettings::override_global(settings, cx);
5040 });
5041
5042 #[allow(clippy::arc_with_non_send_sync)]
5043 let tool = Arc::new(crate::WebSearchTool);
5044 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5045
5046 let input: crate::WebSearchToolInput =
5047 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5048
5049 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5050
5051 let result = task.await;
5052 assert!(result.is_err(), "expected search to be blocked");
5053 assert!(
5054 result.unwrap_err().to_string().contains("blocked"),
5055 "error should mention the search was blocked"
5056 );
5057}
5058
5059#[gpui::test]
5060async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5061 init_test(cx);
5062
5063 let fs = FakeFs::new(cx.executor());
5064 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5065 .await;
5066 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5067
5068 cx.update(|cx| {
5069 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5070 settings.tool_permissions.tools.insert(
5071 EditFileTool::NAME.into(),
5072 agent_settings::ToolRules {
5073 default: Some(settings::ToolPermissionMode::Confirm),
5074 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5075 always_deny: vec![],
5076 always_confirm: vec![],
5077 invalid_patterns: vec![],
5078 },
5079 );
5080 agent_settings::AgentSettings::override_global(settings, cx);
5081 });
5082
5083 let context_server_registry =
5084 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5085 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5086 let templates = crate::Templates::new();
5087 let thread = cx.new(|cx| {
5088 crate::Thread::new(
5089 project.clone(),
5090 cx.new(|_cx| prompt_store::ProjectContext::default()),
5091 context_server_registry,
5092 templates.clone(),
5093 None,
5094 cx,
5095 )
5096 });
5097
5098 #[allow(clippy::arc_with_non_send_sync)]
5099 let tool = Arc::new(crate::EditFileTool::new(
5100 project,
5101 thread.downgrade(),
5102 language_registry,
5103 templates,
5104 ));
5105 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5106
5107 let _task = cx.update(|cx| {
5108 tool.run(
5109 crate::EditFileToolInput {
5110 display_description: "Edit README".to_string(),
5111 path: "root/README.md".into(),
5112 mode: crate::EditFileMode::Edit,
5113 },
5114 event_stream,
5115 cx,
5116 )
5117 });
5118
5119 cx.run_until_parked();
5120
5121 let event = rx.try_next();
5122 assert!(
5123 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5124 "expected no authorization request for allowed .md file"
5125 );
5126}
5127
5128#[gpui::test]
5129async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5130 init_test(cx);
5131
5132 let fs = FakeFs::new(cx.executor());
5133 fs.insert_tree(
5134 "/root",
5135 json!({
5136 ".zed": {
5137 "settings.json": "{}"
5138 },
5139 "README.md": "# Hello"
5140 }),
5141 )
5142 .await;
5143 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5144
5145 cx.update(|cx| {
5146 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5147 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5148 agent_settings::AgentSettings::override_global(settings, cx);
5149 });
5150
5151 let context_server_registry =
5152 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5153 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5154 let templates = crate::Templates::new();
5155 let thread = cx.new(|cx| {
5156 crate::Thread::new(
5157 project.clone(),
5158 cx.new(|_cx| prompt_store::ProjectContext::default()),
5159 context_server_registry,
5160 templates.clone(),
5161 None,
5162 cx,
5163 )
5164 });
5165
5166 #[allow(clippy::arc_with_non_send_sync)]
5167 let tool = Arc::new(crate::EditFileTool::new(
5168 project,
5169 thread.downgrade(),
5170 language_registry,
5171 templates,
5172 ));
5173
5174 // Editing a file inside .zed/ should still prompt even with global default: allow,
5175 // because local settings paths are sensitive and require confirmation regardless.
5176 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5177 let _task = cx.update(|cx| {
5178 tool.run(
5179 crate::EditFileToolInput {
5180 display_description: "Edit local settings".to_string(),
5181 path: "root/.zed/settings.json".into(),
5182 mode: crate::EditFileMode::Edit,
5183 },
5184 event_stream,
5185 cx,
5186 )
5187 });
5188
5189 let _update = rx.expect_update_fields().await;
5190 let _auth = rx.expect_authorization().await;
5191}
5192
5193#[gpui::test]
5194async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5195 init_test(cx);
5196
5197 cx.update(|cx| {
5198 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5199 settings.tool_permissions.tools.insert(
5200 FetchTool::NAME.into(),
5201 agent_settings::ToolRules {
5202 default: Some(settings::ToolPermissionMode::Allow),
5203 always_allow: vec![],
5204 always_deny: vec![
5205 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5206 ],
5207 always_confirm: vec![],
5208 invalid_patterns: vec![],
5209 },
5210 );
5211 agent_settings::AgentSettings::override_global(settings, cx);
5212 });
5213
5214 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5215
5216 #[allow(clippy::arc_with_non_send_sync)]
5217 let tool = Arc::new(crate::FetchTool::new(http_client));
5218 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5219
5220 let input: crate::FetchToolInput =
5221 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5222
5223 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5224
5225 let result = task.await;
5226 assert!(result.is_err(), "expected fetch to be blocked");
5227 assert!(
5228 result.unwrap_err().to_string().contains("blocked"),
5229 "error should mention the fetch was blocked"
5230 );
5231}
5232
5233#[gpui::test]
5234async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5235 init_test(cx);
5236
5237 cx.update(|cx| {
5238 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5239 settings.tool_permissions.tools.insert(
5240 FetchTool::NAME.into(),
5241 agent_settings::ToolRules {
5242 default: Some(settings::ToolPermissionMode::Confirm),
5243 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5244 always_deny: vec![],
5245 always_confirm: vec![],
5246 invalid_patterns: vec![],
5247 },
5248 );
5249 agent_settings::AgentSettings::override_global(settings, cx);
5250 });
5251
5252 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5253
5254 #[allow(clippy::arc_with_non_send_sync)]
5255 let tool = Arc::new(crate::FetchTool::new(http_client));
5256 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5257
5258 let input: crate::FetchToolInput =
5259 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5260
5261 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5262
5263 cx.run_until_parked();
5264
5265 let event = rx.try_next();
5266 assert!(
5267 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5268 "expected no authorization request for allowed docs.rs URL"
5269 );
5270}
5271
5272#[gpui::test]
5273async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5274 init_test(cx);
5275 always_allow_tools(cx);
5276
5277 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5278 let fake_model = model.as_fake();
5279
5280 // Add a tool so we can simulate tool calls
5281 thread.update(cx, |thread, _cx| {
5282 thread.add_tool(EchoTool, None);
5283 });
5284
5285 // Start a turn by sending a message
5286 let mut events = thread
5287 .update(cx, |thread, cx| {
5288 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5289 })
5290 .unwrap();
5291 cx.run_until_parked();
5292
5293 // Simulate the model making a tool call
5294 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5295 LanguageModelToolUse {
5296 id: "tool_1".into(),
5297 name: "echo".into(),
5298 raw_input: r#"{"text": "hello"}"#.into(),
5299 input: json!({"text": "hello"}),
5300 is_input_complete: true,
5301 thought_signature: None,
5302 },
5303 ));
5304 fake_model
5305 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5306
5307 // Signal that a message is queued before ending the stream
5308 thread.update(cx, |thread, _cx| {
5309 thread.set_has_queued_message(true);
5310 });
5311
5312 // Now end the stream - tool will run, and the boundary check should see the queue
5313 fake_model.end_last_completion_stream();
5314
5315 // Collect all events until the turn stops
5316 let all_events = collect_events_until_stop(&mut events, cx).await;
5317
5318 // Verify we received the tool call event
5319 let tool_call_ids: Vec<_> = all_events
5320 .iter()
5321 .filter_map(|e| match e {
5322 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5323 _ => None,
5324 })
5325 .collect();
5326 assert_eq!(
5327 tool_call_ids,
5328 vec!["tool_1"],
5329 "Should have received a tool call event for our echo tool"
5330 );
5331
5332 // The turn should have stopped with EndTurn
5333 let stop_reasons = stop_events(all_events);
5334 assert_eq!(
5335 stop_reasons,
5336 vec![acp::StopReason::EndTurn],
5337 "Turn should have ended after tool completion due to queued message"
5338 );
5339
5340 // Verify the queued message flag is still set
5341 thread.update(cx, |thread, _cx| {
5342 assert!(
5343 thread.has_queued_message(),
5344 "Should still have queued message flag set"
5345 );
5346 });
5347
5348 // Thread should be idle now
5349 thread.update(cx, |thread, _cx| {
5350 assert!(
5351 thread.is_turn_complete(),
5352 "Thread should not be running after turn ends"
5353 );
5354 });
5355}