1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Test double for a terminal handle.
///
/// Records whether `kill()` was called, reports a configurable "stopped by
/// user" flag, and serves canned output, id and exit status.
struct FakeTerminalHandle {
    /// Set to `true` by `kill()`; read by `was_killed()`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user()`; written by `set_stopped_by_user()`.
    stopped_by_user: Arc<AtomicBool>,
    /// One-shot exit signal consumed by `signal_exit()`; `None` once it has been taken.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task handed out by `wait_for_exit()`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned output returned by `current_output()`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned by `id()`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
/// Serves the fake's canned values; `kill` records the call and fires the
/// exit signal.
impl crate::TerminalHandle for FakeTerminalHandle {
    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
        Ok(self.id.clone())
    }

    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
        Ok(self.output.clone())
    }

    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
        Ok(self.wait_for_exit.clone())
    }

    /// Marks the handle as killed, then sends the exit signal so that a
    /// never-exiting handle's `wait_for_exit` task can resolve.
    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
        self.killed.store(true, Ordering::SeqCst);
        self.signal_exit();
        Ok(())
    }

    fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
        Ok(self.stopped_by_user.load(Ordering::SeqCst))
    }
}
157
/// Test double for a subagent handle with a fixed session id.
struct FakeSubagentHandle {
    /// Returned by `id()`.
    session_id: acp::SessionId,
    /// Shared task whose output `wait_for_summary` forwards.
    wait_for_summary_task: Shared<Task<String>>,
}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
/// `ThreadEnvironment` test double that hands out pre-configured handles.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Cloned and returned by every `create_terminal` call; `None` makes that call panic.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Cloned and returned by every `create_subagent` call; `None` makes that call panic.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    /// Every handle created by `create_terminal`, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
/// Sends one user message, streams "Hello" from the fake model followed by an
/// `EndTurn` stop, and checks that the thread ends with that assistant message.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    // Drive pending work so the completion request exists before feeding the fake stream.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
        assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
311
/// Runs the terminal tool with a 5 ms timeout against a fake terminal that
/// never exits on its own. Checks that the handle gets killed and that the
/// tool's result contains the terminal's output.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference to the fake handle so it can be inspected afterwards.
    let handle = environment.terminal_handle.clone().unwrap();

    // The environment is `Rc`-based (not Send/Sync), hence the lint allowance
    // for wrapping the tool in an `Arc`.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the tool task without blocking. Between polls, run the executor and
    // wait 1 ms. Give up once a 500 ms wall-clock deadline has passed.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
377
/// Runs the terminal tool with no timeout against a never-exiting fake
/// terminal. Checks that the handle has not been killed after a 25 ms timer.
///
/// NOTE(review): this test is `#[ignore]`d with no reason stated here.
/// TODO: document why, or re-enable it.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference to the fake handle so it can be inspected afterwards.
    let handle = environment.terminal_handle.clone().unwrap();

    // The environment is `Rc`-based (not Send/Sync), hence the lint allowance
    // for wrapping the tool in an `Arc`.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // The task is held, never awaited, so the tool keeps running while we check the handle.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
427
/// Streams a `Thinking` event followed by text. Checks that the assistant
/// message renders as a `<think>` block followed by the answer.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
470
471#[gpui::test]
472async fn test_system_prompt(cx: &mut TestAppContext) {
473 let ThreadTest {
474 model,
475 thread,
476 project_context,
477 ..
478 } = setup(cx, TestModel::Fake).await;
479 let fake_model = model.as_fake();
480
481 project_context.update(cx, |project_context, _cx| {
482 project_context.shell = "test-shell".into()
483 });
484 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["abc"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491 let mut pending_completions = fake_model.pending_completions();
492 assert_eq!(
493 pending_completions.len(),
494 1,
495 "unexpected pending completions: {:?}",
496 pending_completions
497 );
498
499 let pending_completion = pending_completions.pop().unwrap();
500 assert_eq!(pending_completion.messages[0].role, Role::System);
501
502 let system_message = &pending_completion.messages[0];
503 let system_prompt = system_message.content[0].to_str().unwrap();
504 assert!(
505 system_prompt.contains("test-shell"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509 assert!(
510 system_prompt.contains("## Fixing Diagnostics"),
511 "unexpected system message: {:?}",
512 system_message
513 );
514}
515
516#[gpui::test]
517async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
519 let fake_model = model.as_fake();
520
521 thread
522 .update(cx, |thread, cx| {
523 thread.send(UserMessageId::new(), ["abc"], cx)
524 })
525 .unwrap();
526 cx.run_until_parked();
527 let mut pending_completions = fake_model.pending_completions();
528 assert_eq!(
529 pending_completions.len(),
530 1,
531 "unexpected pending completions: {:?}",
532 pending_completions
533 );
534
535 let pending_completion = pending_completions.pop().unwrap();
536 assert_eq!(pending_completion.messages[0].role, Role::System);
537
538 let system_message = &pending_completion.messages[0];
539 let system_prompt = system_message.content[0].to_str().unwrap();
540 assert!(
541 !system_prompt.contains("## Tool Use"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545 assert!(
546 !system_prompt.contains("## Fixing Diagnostics"),
547 "unexpected system message: {:?}",
548 system_message
549 );
550}
551
552#[gpui::test]
553async fn test_prompt_caching(cx: &mut TestAppContext) {
554 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
555 let fake_model = model.as_fake();
556
557 // Send initial user message and verify it's cached
558 thread
559 .update(cx, |thread, cx| {
560 thread.send(UserMessageId::new(), ["Message 1"], cx)
561 })
562 .unwrap();
563 cx.run_until_parked();
564
565 let completion = fake_model.pending_completions().pop().unwrap();
566 assert_eq!(
567 completion.messages[1..],
568 vec![LanguageModelRequestMessage {
569 role: Role::User,
570 content: vec!["Message 1".into()],
571 cache: true,
572 reasoning_details: None,
573 }]
574 );
575 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
576 "Response to Message 1".into(),
577 ));
578 fake_model.end_last_completion_stream();
579 cx.run_until_parked();
580
581 // Send another user message and verify only the latest is cached
582 thread
583 .update(cx, |thread, cx| {
584 thread.send(UserMessageId::new(), ["Message 2"], cx)
585 })
586 .unwrap();
587 cx.run_until_parked();
588
589 let completion = fake_model.pending_completions().pop().unwrap();
590 assert_eq!(
591 completion.messages[1..],
592 vec![
593 LanguageModelRequestMessage {
594 role: Role::User,
595 content: vec!["Message 1".into()],
596 cache: false,
597 reasoning_details: None,
598 },
599 LanguageModelRequestMessage {
600 role: Role::Assistant,
601 content: vec!["Response to Message 1".into()],
602 cache: false,
603 reasoning_details: None,
604 },
605 LanguageModelRequestMessage {
606 role: Role::User,
607 content: vec!["Message 2".into()],
608 cache: true,
609 reasoning_details: None,
610 }
611 ]
612 );
613 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
614 "Response to Message 2".into(),
615 ));
616 fake_model.end_last_completion_stream();
617 cx.run_until_parked();
618
619 // Simulate a tool call and verify that the latest tool result is cached
620 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
621 thread
622 .update(cx, |thread, cx| {
623 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
624 })
625 .unwrap();
626 cx.run_until_parked();
627
628 let tool_use = LanguageModelToolUse {
629 id: "tool_1".into(),
630 name: EchoTool::NAME.into(),
631 raw_input: json!({"text": "test"}).to_string(),
632 input: json!({"text": "test"}),
633 is_input_complete: true,
634 thought_signature: None,
635 };
636 fake_model
637 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
638 fake_model.end_last_completion_stream();
639 cx.run_until_parked();
640
641 let completion = fake_model.pending_completions().pop().unwrap();
642 let tool_result = LanguageModelToolResult {
643 tool_use_id: "tool_1".into(),
644 tool_name: EchoTool::NAME.into(),
645 is_error: false,
646 content: "test".into(),
647 output: Some("test".into()),
648 };
649 assert_eq!(
650 completion.messages[1..],
651 vec![
652 LanguageModelRequestMessage {
653 role: Role::User,
654 content: vec!["Message 1".into()],
655 cache: false,
656 reasoning_details: None,
657 },
658 LanguageModelRequestMessage {
659 role: Role::Assistant,
660 content: vec!["Response to Message 1".into()],
661 cache: false,
662 reasoning_details: None,
663 },
664 LanguageModelRequestMessage {
665 role: Role::User,
666 content: vec!["Message 2".into()],
667 cache: false,
668 reasoning_details: None,
669 },
670 LanguageModelRequestMessage {
671 role: Role::Assistant,
672 content: vec!["Response to Message 2".into()],
673 cache: false,
674 reasoning_details: None,
675 },
676 LanguageModelRequestMessage {
677 role: Role::User,
678 content: vec!["Use the echo tool".into()],
679 cache: false,
680 reasoning_details: None,
681 },
682 LanguageModelRequestMessage {
683 role: Role::Assistant,
684 content: vec![MessageContent::ToolUse(tool_use)],
685 cache: false,
686 reasoning_details: None,
687 },
688 LanguageModelRequestMessage {
689 role: Role::User,
690 content: vec![MessageContent::ToolResult(tool_result)],
691 cache: true,
692 reasoning_details: None,
693 }
694 ]
695 );
696}
697
/// End-to-end test against a real model, ignored unless the `e2e` feature is
/// enabled. It exercises two tool calls:
/// - one likely to finish before streaming stops (echo);
/// - one likely to finish after it (delay).
///
/// It then checks that the final agent message mentions "Ding".
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool, None);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
757
/// End-to-end test, ignored unless the `e2e` feature is enabled, checking that
/// tool input is streamed:
/// - at least one pending `word_list` tool call must be observed with
///   incomplete input that still lacks key "g";
/// - any non-pending call must already contain keys "a" and "g".
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool, None);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Record that we caught the input mid-stream (before "g" arrived).
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
809
/// Exercises the permission flow for a tool that requires authorization:
/// - allow once;
/// - deny;
/// - "always allow";
/// - a final call that must then run without prompting.
///
/// After each step it checks the tool results sent back to the model.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission, None);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two tool calls in the same completion, each needing its own authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries one allowed result and one denied (error) result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
948
/// A tool use naming a tool that was never registered must surface as a
/// pending tool call, followed by an update that marks it `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
/// A terminal command with a recognizable program offers three choices:
/// always for the tool, always for that program, and once.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo` commands"));
    assert!(labels.contains(&"Only this time"));
}
1059
/// Editing a file offers a tool-wide choice and a choice scoped to its directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
}
1077
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a docs.rs URL should offer both a tool-wide option and an option
    // scoped to the `docs.rs` domain.
    let options = ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()])
        .build_permission_options();

    let choices = match options {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(labels.contains(&expected));
    }
}
1095
#[test]
fn test_permission_options_without_pattern() {
    // For `./deploy.sh --production` only two choices are expected: tool-wide
    // and one-off. No command-scoped choice should be offered.
    let options = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    )
    .build_permission_options();

    let choices = match options {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 2);

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(labels.contains(&expected));
    }
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1117
#[test]
fn test_permission_option_ids_for_terminal() {
    let options = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options();

    let choices = match options {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Each dropdown choice pairs an allow option with a deny option; split the
    // ids of the two sides into separate lists.
    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    let has_id = |ids: &[String], wanted: &str| ids.iter().any(|id| id == wanted);
    let has_prefix = |ids: &[String], prefix: &str| ids.iter().any(|id| id.starts_with(prefix));

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    assert!(
        has_prefix(&allow_ids, "always_allow_pattern:terminal\n"),
        "Missing allow pattern option"
    );

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    assert!(
        has_prefix(&deny_ids, "always_deny_pattern:terminal\n"),
        "Missing deny pattern option"
    );
}
1157
1158#[gpui::test]
1159#[cfg_attr(not(feature = "e2e"), ignore)]
1160async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1161 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1162
1163 // Test concurrent tool calls with different delay times
1164 let events = thread
1165 .update(cx, |thread, cx| {
1166 thread.add_tool(DelayTool, None);
1167 thread.send(
1168 UserMessageId::new(),
1169 [
1170 "Call the delay tool twice in the same message.",
1171 "Once with 100ms. Once with 300ms.",
1172 "When both timers are complete, describe the outputs.",
1173 ],
1174 cx,
1175 )
1176 })
1177 .unwrap()
1178 .collect()
1179 .await;
1180
1181 let stop_reasons = stop_events(events);
1182 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1183
1184 thread.update(cx, |thread, _cx| {
1185 let last_message = thread.last_message().unwrap();
1186 let agent_message = last_message.as_agent_message().unwrap();
1187 let text = agent_message
1188 .content
1189 .iter()
1190 .filter_map(|content| {
1191 if let AgentMessageContent::Text(text) = content {
1192 Some(text.as_str())
1193 } else {
1194 None
1195 }
1196 })
1197 .collect::<String>();
1198
1199 assert!(text.contains("Ding"));
1200 });
1201}
1202
1203#[gpui::test]
1204async fn test_profiles(cx: &mut TestAppContext) {
1205 let ThreadTest {
1206 model, thread, fs, ..
1207 } = setup(cx, TestModel::Fake).await;
1208 let fake_model = model.as_fake();
1209
1210 thread.update(cx, |thread, _cx| {
1211 thread.add_tool(DelayTool, None);
1212 thread.add_tool(EchoTool, None);
1213 thread.add_tool(InfiniteTool, None);
1214 });
1215
1216 // Override profiles and wait for settings to be loaded.
1217 fs.insert_file(
1218 paths::settings_file(),
1219 json!({
1220 "agent": {
1221 "profiles": {
1222 "test-1": {
1223 "name": "Test Profile 1",
1224 "tools": {
1225 EchoTool::NAME: true,
1226 DelayTool::NAME: true,
1227 }
1228 },
1229 "test-2": {
1230 "name": "Test Profile 2",
1231 "tools": {
1232 InfiniteTool::NAME: true,
1233 }
1234 }
1235 }
1236 }
1237 })
1238 .to_string()
1239 .into_bytes(),
1240 )
1241 .await;
1242 cx.run_until_parked();
1243
1244 // Test that test-1 profile (default) has echo and delay tools
1245 thread
1246 .update(cx, |thread, cx| {
1247 thread.set_profile(AgentProfileId("test-1".into()), cx);
1248 thread.send(UserMessageId::new(), ["test"], cx)
1249 })
1250 .unwrap();
1251 cx.run_until_parked();
1252
1253 let mut pending_completions = fake_model.pending_completions();
1254 assert_eq!(pending_completions.len(), 1);
1255 let completion = pending_completions.pop().unwrap();
1256 let tool_names: Vec<String> = completion
1257 .tools
1258 .iter()
1259 .map(|tool| tool.name.clone())
1260 .collect();
1261 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1262 fake_model.end_last_completion_stream();
1263
1264 // Switch to test-2 profile, and verify that it has only the infinite tool.
1265 thread
1266 .update(cx, |thread, cx| {
1267 thread.set_profile(AgentProfileId("test-2".into()), cx);
1268 thread.send(UserMessageId::new(), ["test2"], cx)
1269 })
1270 .unwrap();
1271 cx.run_until_parked();
1272 let mut pending_completions = fake_model.pending_completions();
1273 assert_eq!(pending_completions.len(), 1);
1274 let completion = pending_completions.pop().unwrap();
1275 let tool_names: Vec<String> = completion
1276 .tools
1277 .iter()
1278 .map(|tool| tool.name.clone())
1279 .collect();
1280 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1281}
1282
1283#[gpui::test]
1284async fn test_mcp_tools(cx: &mut TestAppContext) {
1285 let ThreadTest {
1286 model,
1287 thread,
1288 context_server_store,
1289 fs,
1290 ..
1291 } = setup(cx, TestModel::Fake).await;
1292 let fake_model = model.as_fake();
1293
1294 // Override profiles and wait for settings to be loaded.
1295 fs.insert_file(
1296 paths::settings_file(),
1297 json!({
1298 "agent": {
1299 "tool_permissions": { "default": "allow" },
1300 "profiles": {
1301 "test": {
1302 "name": "Test Profile",
1303 "enable_all_context_servers": true,
1304 "tools": {
1305 EchoTool::NAME: true,
1306 }
1307 },
1308 }
1309 }
1310 })
1311 .to_string()
1312 .into_bytes(),
1313 )
1314 .await;
1315 cx.run_until_parked();
1316 thread.update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test".into()), cx)
1318 });
1319
1320 let mut mcp_tool_calls = setup_context_server(
1321 "test_server",
1322 vec![context_server::types::Tool {
1323 name: "echo".into(),
1324 description: None,
1325 input_schema: serde_json::to_value(EchoTool::input_schema(
1326 LanguageModelToolSchemaFormat::JsonSchema,
1327 ))
1328 .unwrap(),
1329 output_schema: None,
1330 annotations: None,
1331 }],
1332 &context_server_store,
1333 cx,
1334 );
1335
1336 let events = thread.update(cx, |thread, cx| {
1337 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1338 });
1339 cx.run_until_parked();
1340
1341 // Simulate the model calling the MCP tool.
1342 let completion = fake_model.pending_completions().pop().unwrap();
1343 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1344 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1345 LanguageModelToolUse {
1346 id: "tool_1".into(),
1347 name: "echo".into(),
1348 raw_input: json!({"text": "test"}).to_string(),
1349 input: json!({"text": "test"}),
1350 is_input_complete: true,
1351 thought_signature: None,
1352 },
1353 ));
1354 fake_model.end_last_completion_stream();
1355 cx.run_until_parked();
1356
1357 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1358 assert_eq!(tool_call_params.name, "echo");
1359 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1360 tool_call_response
1361 .send(context_server::types::CallToolResponse {
1362 content: vec![context_server::types::ToolResponseContent::Text {
1363 text: "test".into(),
1364 }],
1365 is_error: None,
1366 meta: None,
1367 structured_content: None,
1368 })
1369 .unwrap();
1370 cx.run_until_parked();
1371
1372 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1373 fake_model.send_last_completion_stream_text_chunk("Done!");
1374 fake_model.end_last_completion_stream();
1375 events.collect::<Vec<_>>().await;
1376
1377 // Send again after adding the echo tool, ensuring the name collision is resolved.
1378 let events = thread.update(cx, |thread, cx| {
1379 thread.add_tool(EchoTool, None);
1380 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1381 });
1382 cx.run_until_parked();
1383 let completion = fake_model.pending_completions().pop().unwrap();
1384 assert_eq!(
1385 tool_names_for_completion(&completion),
1386 vec!["echo", "test_server_echo"]
1387 );
1388 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1389 LanguageModelToolUse {
1390 id: "tool_2".into(),
1391 name: "test_server_echo".into(),
1392 raw_input: json!({"text": "mcp"}).to_string(),
1393 input: json!({"text": "mcp"}),
1394 is_input_complete: true,
1395 thought_signature: None,
1396 },
1397 ));
1398 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1399 LanguageModelToolUse {
1400 id: "tool_3".into(),
1401 name: "echo".into(),
1402 raw_input: json!({"text": "native"}).to_string(),
1403 input: json!({"text": "native"}),
1404 is_input_complete: true,
1405 thought_signature: None,
1406 },
1407 ));
1408 fake_model.end_last_completion_stream();
1409 cx.run_until_parked();
1410
1411 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1412 assert_eq!(tool_call_params.name, "echo");
1413 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1414 tool_call_response
1415 .send(context_server::types::CallToolResponse {
1416 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1417 is_error: None,
1418 meta: None,
1419 structured_content: None,
1420 })
1421 .unwrap();
1422 cx.run_until_parked();
1423
1424 // Ensure the tool results were inserted with the correct names.
1425 let completion = fake_model.pending_completions().pop().unwrap();
1426 assert_eq!(
1427 completion.messages.last().unwrap().content,
1428 vec![
1429 MessageContent::ToolResult(LanguageModelToolResult {
1430 tool_use_id: "tool_3".into(),
1431 tool_name: "echo".into(),
1432 is_error: false,
1433 content: "native".into(),
1434 output: Some("native".into()),
1435 },),
1436 MessageContent::ToolResult(LanguageModelToolResult {
1437 tool_use_id: "tool_2".into(),
1438 tool_name: "test_server_echo".into(),
1439 is_error: false,
1440 content: "mcp".into(),
1441 output: Some("mcp".into()),
1442 },),
1443 ]
1444 );
1445 fake_model.end_last_completion_stream();
1446 events.collect::<Vec<_>>().await;
1447}
1448
1449#[gpui::test]
1450async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1451 let ThreadTest {
1452 model,
1453 thread,
1454 context_server_store,
1455 fs,
1456 ..
1457 } = setup(cx, TestModel::Fake).await;
1458 let fake_model = model.as_fake();
1459
1460 // Set up a profile with all tools enabled
1461 fs.insert_file(
1462 paths::settings_file(),
1463 json!({
1464 "agent": {
1465 "profiles": {
1466 "test": {
1467 "name": "Test Profile",
1468 "enable_all_context_servers": true,
1469 "tools": {
1470 EchoTool::NAME: true,
1471 DelayTool::NAME: true,
1472 WordListTool::NAME: true,
1473 ToolRequiringPermission::NAME: true,
1474 InfiniteTool::NAME: true,
1475 }
1476 },
1477 }
1478 }
1479 })
1480 .to_string()
1481 .into_bytes(),
1482 )
1483 .await;
1484 cx.run_until_parked();
1485
1486 thread.update(cx, |thread, cx| {
1487 thread.set_profile(AgentProfileId("test".into()), cx);
1488 thread.add_tool(EchoTool, None);
1489 thread.add_tool(DelayTool, None);
1490 thread.add_tool(WordListTool, None);
1491 thread.add_tool(ToolRequiringPermission, None);
1492 thread.add_tool(InfiniteTool, None);
1493 });
1494
1495 // Set up multiple context servers with some overlapping tool names
1496 let _server1_calls = setup_context_server(
1497 "xxx",
1498 vec![
1499 context_server::types::Tool {
1500 name: "echo".into(), // Conflicts with native EchoTool
1501 description: None,
1502 input_schema: serde_json::to_value(EchoTool::input_schema(
1503 LanguageModelToolSchemaFormat::JsonSchema,
1504 ))
1505 .unwrap(),
1506 output_schema: None,
1507 annotations: None,
1508 },
1509 context_server::types::Tool {
1510 name: "unique_tool_1".into(),
1511 description: None,
1512 input_schema: json!({"type": "object", "properties": {}}),
1513 output_schema: None,
1514 annotations: None,
1515 },
1516 ],
1517 &context_server_store,
1518 cx,
1519 );
1520
1521 let _server2_calls = setup_context_server(
1522 "yyy",
1523 vec![
1524 context_server::types::Tool {
1525 name: "echo".into(), // Also conflicts with native EchoTool
1526 description: None,
1527 input_schema: serde_json::to_value(EchoTool::input_schema(
1528 LanguageModelToolSchemaFormat::JsonSchema,
1529 ))
1530 .unwrap(),
1531 output_schema: None,
1532 annotations: None,
1533 },
1534 context_server::types::Tool {
1535 name: "unique_tool_2".into(),
1536 description: None,
1537 input_schema: json!({"type": "object", "properties": {}}),
1538 output_schema: None,
1539 annotations: None,
1540 },
1541 context_server::types::Tool {
1542 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1543 description: None,
1544 input_schema: json!({"type": "object", "properties": {}}),
1545 output_schema: None,
1546 annotations: None,
1547 },
1548 context_server::types::Tool {
1549 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1550 description: None,
1551 input_schema: json!({"type": "object", "properties": {}}),
1552 output_schema: None,
1553 annotations: None,
1554 },
1555 ],
1556 &context_server_store,
1557 cx,
1558 );
1559 let _server3_calls = setup_context_server(
1560 "zzz",
1561 vec![
1562 context_server::types::Tool {
1563 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1564 description: None,
1565 input_schema: json!({"type": "object", "properties": {}}),
1566 output_schema: None,
1567 annotations: None,
1568 },
1569 context_server::types::Tool {
1570 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1571 description: None,
1572 input_schema: json!({"type": "object", "properties": {}}),
1573 output_schema: None,
1574 annotations: None,
1575 },
1576 context_server::types::Tool {
1577 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1578 description: None,
1579 input_schema: json!({"type": "object", "properties": {}}),
1580 output_schema: None,
1581 annotations: None,
1582 },
1583 ],
1584 &context_server_store,
1585 cx,
1586 );
1587
1588 thread
1589 .update(cx, |thread, cx| {
1590 thread.send(UserMessageId::new(), ["Go"], cx)
1591 })
1592 .unwrap();
1593 cx.run_until_parked();
1594 let completion = fake_model.pending_completions().pop().unwrap();
1595 assert_eq!(
1596 tool_names_for_completion(&completion),
1597 vec![
1598 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1599 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1600 "delay",
1601 "echo",
1602 "infinite",
1603 "tool_requiring_permission",
1604 "unique_tool_1",
1605 "unique_tool_2",
1606 "word_list",
1607 "xxx_echo",
1608 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1609 "yyy_echo",
1610 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1611 ]
1612 );
1613}
1614
1615#[gpui::test]
1616#[cfg_attr(not(feature = "e2e"), ignore)]
1617async fn test_cancellation(cx: &mut TestAppContext) {
1618 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1619
1620 let mut events = thread
1621 .update(cx, |thread, cx| {
1622 thread.add_tool(InfiniteTool, None);
1623 thread.add_tool(EchoTool, None);
1624 thread.send(
1625 UserMessageId::new(),
1626 ["Call the echo tool, then call the infinite tool, then explain their output"],
1627 cx,
1628 )
1629 })
1630 .unwrap();
1631
1632 // Wait until both tools are called.
1633 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1634 let mut echo_id = None;
1635 let mut echo_completed = false;
1636 while let Some(event) = events.next().await {
1637 match event.unwrap() {
1638 ThreadEvent::ToolCall(tool_call) => {
1639 assert_eq!(tool_call.title, expected_tools.remove(0));
1640 if tool_call.title == "Echo" {
1641 echo_id = Some(tool_call.tool_call_id);
1642 }
1643 }
1644 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1645 acp::ToolCallUpdate {
1646 tool_call_id,
1647 fields:
1648 acp::ToolCallUpdateFields {
1649 status: Some(acp::ToolCallStatus::Completed),
1650 ..
1651 },
1652 ..
1653 },
1654 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1655 echo_completed = true;
1656 }
1657 _ => {}
1658 }
1659
1660 if expected_tools.is_empty() && echo_completed {
1661 break;
1662 }
1663 }
1664
1665 // Cancel the current send and ensure that the event stream is closed, even
1666 // if one of the tools is still running.
1667 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1668 let events = events.collect::<Vec<_>>().await;
1669 let last_event = events.last();
1670 assert!(
1671 matches!(
1672 last_event,
1673 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1674 ),
1675 "unexpected event {last_event:?}"
1676 );
1677
1678 // Ensure we can still send a new message after cancellation.
1679 let events = thread
1680 .update(cx, |thread, cx| {
1681 thread.send(
1682 UserMessageId::new(),
1683 ["Testing: reply with 'Hello' then stop."],
1684 cx,
1685 )
1686 })
1687 .unwrap()
1688 .collect::<Vec<_>>()
1689 .await;
1690 thread.update(cx, |thread, _cx| {
1691 let message = thread.last_message().unwrap();
1692 let agent_message = message.as_agent_message().unwrap();
1693 assert_eq!(
1694 agent_message.content,
1695 vec![AgentMessageContent::Text("Hello".to_string())]
1696 );
1697 });
1698 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1699}
1700
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    // Cancelling the thread while a terminal command is running must:
    // - kill the terminal,
    // - end the turn with `StopReason::Cancelled`,
    // - record the terminal's output plus a "user stopped" note as the tool result.
    // It must also leave the thread able to handle a new message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    // NOTE(review): presumably auto-approves tool permissions so the terminal tool
    // is not blocked on an authorization prompt; confirm against the helper.
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Environment backed by a fake terminal built with `new_never_exits`. Keep a
    // clone of its handle so we can inspect it after cancellation.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    // Let the send reach the fake model so there is a pending completion to drive.
    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running. The cancel task is
    // detached; the executor is driven by `collect_events_until_stop` below.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // The first tool use in the agent message is the terminal call made above.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // Tool results are keyed by tool use id.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
1797
1798#[gpui::test]
1799async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1800 // This test verifies that tools which properly handle cancellation via
1801 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1802 // to cancellation and report that they were cancelled.
1803 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1804 always_allow_tools(cx);
1805 let fake_model = model.as_fake();
1806
1807 let (tool, was_cancelled) = CancellationAwareTool::new();
1808
1809 let mut events = thread
1810 .update(cx, |thread, cx| {
1811 thread.add_tool(tool, None);
1812 thread.send(
1813 UserMessageId::new(),
1814 ["call the cancellation aware tool"],
1815 cx,
1816 )
1817 })
1818 .unwrap();
1819
1820 cx.run_until_parked();
1821
1822 // Simulate the model calling the cancellation-aware tool
1823 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1824 LanguageModelToolUse {
1825 id: "cancellation_aware_1".into(),
1826 name: "cancellation_aware".into(),
1827 raw_input: r#"{}"#.into(),
1828 input: json!({}),
1829 is_input_complete: true,
1830 thought_signature: None,
1831 },
1832 ));
1833 fake_model.end_last_completion_stream();
1834
1835 cx.run_until_parked();
1836
1837 // Wait for the tool call to be reported
1838 let mut tool_started = false;
1839 let deadline = cx.executor().num_cpus() * 100;
1840 for _ in 0..deadline {
1841 cx.run_until_parked();
1842
1843 while let Some(Some(event)) = events.next().now_or_never() {
1844 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1845 if tool_call.title == "Cancellation Aware Tool" {
1846 tool_started = true;
1847 break;
1848 }
1849 }
1850 }
1851
1852 if tool_started {
1853 break;
1854 }
1855
1856 cx.background_executor
1857 .timer(Duration::from_millis(10))
1858 .await;
1859 }
1860 assert!(tool_started, "expected cancellation aware tool to start");
1861
1862 // Cancel the thread and wait for it to complete
1863 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1864
1865 // The cancel task should complete promptly because the tool handles cancellation
1866 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1867 futures::select! {
1868 _ = cancel_task.fuse() => {}
1869 _ = timeout.fuse() => {
1870 panic!("cancel task timed out - tool did not respond to cancellation");
1871 }
1872 }
1873
1874 // Verify the tool detected cancellation via its flag
1875 assert!(
1876 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1877 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1878 );
1879
1880 // Collect remaining events
1881 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1882
1883 // Verify we got a cancellation stop event
1884 assert_eq!(
1885 stop_events(remaining_events),
1886 vec![acp::StopReason::Cancelled],
1887 );
1888
1889 // Verify we can send a new message after cancellation
1890 verify_thread_recovery(&thread, &fake_model, cx).await;
1891}
1892
1893/// Helper to verify thread can recover after cancellation by sending a simple message.
1894async fn verify_thread_recovery(
1895 thread: &Entity<Thread>,
1896 fake_model: &FakeLanguageModel,
1897 cx: &mut TestAppContext,
1898) {
1899 let events = thread
1900 .update(cx, |thread, cx| {
1901 thread.send(
1902 UserMessageId::new(),
1903 ["Testing: reply with 'Hello' then stop."],
1904 cx,
1905 )
1906 })
1907 .unwrap();
1908 cx.run_until_parked();
1909 fake_model.send_last_completion_stream_text_chunk("Hello");
1910 fake_model
1911 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1912 fake_model.end_last_completion_stream();
1913
1914 let events = events.collect::<Vec<_>>().await;
1915 thread.update(cx, |thread, _cx| {
1916 let message = thread.last_message().unwrap();
1917 let agent_message = message.as_agent_message().unwrap();
1918 assert_eq!(
1919 agent_message.content,
1920 vec![AgentMessageContent::Text("Hello".to_string())]
1921 );
1922 });
1923 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1924}
1925
1926/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1927async fn wait_for_terminal_tool_started(
1928 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1929 cx: &mut TestAppContext,
1930) {
1931 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1932 for _ in 0..deadline {
1933 cx.run_until_parked();
1934
1935 while let Some(Some(event)) = events.next().now_or_never() {
1936 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1937 update,
1938 ))) = &event
1939 {
1940 if update.fields.content.as_ref().is_some_and(|content| {
1941 content
1942 .iter()
1943 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1944 }) {
1945 return;
1946 }
1947 }
1948 }
1949
1950 cx.background_executor
1951 .timer(Duration::from_millis(10))
1952 .await;
1953 }
1954 panic!("terminal tool did not start within the expected time");
1955}
1956
1957/// Collects events until a Stop event is received, driving the executor to completion.
1958async fn collect_events_until_stop(
1959 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1960 cx: &mut TestAppContext,
1961) -> Vec<Result<ThreadEvent>> {
1962 let mut collected = Vec::new();
1963 let deadline = cx.executor().num_cpus() * 200;
1964
1965 for _ in 0..deadline {
1966 cx.executor().advance_clock(Duration::from_millis(10));
1967 cx.run_until_parked();
1968
1969 while let Some(Some(event)) = events.next().now_or_never() {
1970 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1971 collected.push(event);
1972 if is_stop {
1973 return collected;
1974 }
1975 }
1976 }
1977 panic!(
1978 "did not receive Stop event within the expected time; collected {} events",
1979 collected.len()
1980 );
1981}
1982
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    // Truncating the thread at the message that started a running terminal tool
    // must kill the terminal and leave the thread empty. The thread must still
    // accept new messages afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    // NOTE(review): presumably auto-approves tool permissions so the terminal tool
    // runs without an authorization prompt; confirm against the helper.
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Environment backed by a fake terminal built with `new_never_exits`. Keep a
    // clone of its handle to check whether it was killed.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    // Keep the user message id: it is the truncation point used below.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    // Let the send reach the fake model so there is a pending completion to drive.
    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete. The collected events are
    // not inspected; this only waits for the Stop event.
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2049
#[gpui::test]
async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
    // Tests that cancellation properly kills all running terminal tools when multiple are active.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    // NOTE(review): presumably auto-approves tool permissions so both terminal
    // tools run without authorization prompts; confirm against the helper.
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Shared environment that records every terminal handle it creates. The `Rc`
    // is cloned into the tool so the test can still inspect the handles afterwards.
    let environment = Rc::new(MultiTerminalEnvironment::new());

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment.clone()),
                None,
            );
            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
        })
        .unwrap();

    // Let the send reach the fake model so there is a pending completion to drive.
    cx.run_until_parked();

    // Simulate the model calling two terminal tools
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_2".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 2000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for both terminal tools to start by counting terminal content updates.
    // This mirrors `wait_for_terminal_tool_started`, but waits for two updates.
    let mut terminals_started = 0;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                update,
            ))) = &event
            {
                if update.fields.content.as_ref().is_some_and(|content| {
                    content
                        .iter()
                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
                }) {
                    terminals_started += 1;
                    if terminals_started >= 2 {
                        break;
                    }
                }
            }
        }
        if terminals_started >= 2 {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(
        terminals_started >= 2,
        "expected 2 terminal tools to start, got {terminals_started}"
    );

    // Cancel the thread while both terminals are running. The cancel task is
    // detached; the executor is driven by `collect_events_until_stop` below.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify both terminal handles were killed
    let handles = environment.handles();
    assert_eq!(
        handles.len(),
        2,
        "expected 2 terminal handles to be created"
    );
    assert!(
        handles[0].was_killed(),
        "expected first terminal handle to be killed on cancellation"
    );
    assert!(
        handles[1].was_killed(),
        "expected second terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );
}
2158
2159#[gpui::test]
2160async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2161 // Tests that clicking the stop button on the terminal card (as opposed to the main
2162 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2163 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2164 always_allow_tools(cx);
2165 let fake_model = model.as_fake();
2166
2167 let environment = Rc::new(cx.update(|cx| {
2168 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2169 }));
2170 let handle = environment.terminal_handle.clone().unwrap();
2171
2172 let mut events = thread
2173 .update(cx, |thread, cx| {
2174 thread.add_tool(
2175 crate::TerminalTool::new(thread.project().clone(), environment),
2176 None,
2177 );
2178 thread.send(UserMessageId::new(), ["run a command"], cx)
2179 })
2180 .unwrap();
2181
2182 cx.run_until_parked();
2183
2184 // Simulate the model calling the terminal tool
2185 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2186 LanguageModelToolUse {
2187 id: "terminal_tool_1".into(),
2188 name: TerminalTool::NAME.into(),
2189 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2190 input: json!({"command": "sleep 1000", "cd": "."}),
2191 is_input_complete: true,
2192 thought_signature: None,
2193 },
2194 ));
2195 fake_model.end_last_completion_stream();
2196
2197 // Wait for the terminal tool to start running
2198 wait_for_terminal_tool_started(&mut events, cx).await;
2199
2200 // Simulate user clicking stop on the terminal card itself.
2201 // This sets the flag and signals exit (simulating what the real UI would do).
2202 handle.set_stopped_by_user(true);
2203 handle.killed.store(true, Ordering::SeqCst);
2204 handle.signal_exit();
2205
2206 // Wait for the tool to complete
2207 cx.run_until_parked();
2208
2209 // The thread continues after tool completion - simulate the model ending its turn
2210 fake_model
2211 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2212 fake_model.end_last_completion_stream();
2213
2214 // Collect remaining events
2215 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2216
2217 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2218 assert_eq!(
2219 stop_events(remaining_events),
2220 vec![acp::StopReason::EndTurn],
2221 );
2222
2223 // Verify the tool result indicates user stopped
2224 thread.update(cx, |thread, _cx| {
2225 let message = thread.last_message().unwrap();
2226 let agent_message = message.as_agent_message().unwrap();
2227
2228 let tool_use = agent_message
2229 .content
2230 .iter()
2231 .find_map(|content| match content {
2232 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2233 _ => None,
2234 })
2235 .expect("expected tool use in agent message");
2236
2237 let tool_result = agent_message
2238 .tool_results
2239 .get(&tool_use.id)
2240 .expect("expected tool result");
2241
2242 let result_text = match &tool_result.content {
2243 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2244 _ => panic!("expected text content in tool result"),
2245 };
2246
2247 assert!(
2248 result_text.contains("The user stopped this command"),
2249 "expected tool result to indicate user stopped, got: {result_text}"
2250 );
2251 });
2252}
2253
2254#[gpui::test]
2255async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2256 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2257 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2258 always_allow_tools(cx);
2259 let fake_model = model.as_fake();
2260
2261 let environment = Rc::new(cx.update(|cx| {
2262 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2263 }));
2264 let handle = environment.terminal_handle.clone().unwrap();
2265
2266 let mut events = thread
2267 .update(cx, |thread, cx| {
2268 thread.add_tool(
2269 crate::TerminalTool::new(thread.project().clone(), environment),
2270 None,
2271 );
2272 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2273 })
2274 .unwrap();
2275
2276 cx.run_until_parked();
2277
2278 // Simulate the model calling the terminal tool with a short timeout
2279 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2280 LanguageModelToolUse {
2281 id: "terminal_tool_1".into(),
2282 name: TerminalTool::NAME.into(),
2283 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2284 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2285 is_input_complete: true,
2286 thought_signature: None,
2287 },
2288 ));
2289 fake_model.end_last_completion_stream();
2290
2291 // Wait for the terminal tool to start running
2292 wait_for_terminal_tool_started(&mut events, cx).await;
2293
2294 // Advance clock past the timeout
2295 cx.executor().advance_clock(Duration::from_millis(200));
2296 cx.run_until_parked();
2297
2298 // The thread continues after tool completion - simulate the model ending its turn
2299 fake_model
2300 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2301 fake_model.end_last_completion_stream();
2302
2303 // Collect remaining events
2304 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2305
2306 // Verify the terminal was killed due to timeout
2307 assert!(
2308 handle.was_killed(),
2309 "expected terminal handle to be killed on timeout"
2310 );
2311
2312 // Verify we got an EndTurn (the tool completed, just with timeout)
2313 assert_eq!(
2314 stop_events(remaining_events),
2315 vec![acp::StopReason::EndTurn],
2316 );
2317
2318 // Verify the tool result indicates timeout, not user stopped
2319 thread.update(cx, |thread, _cx| {
2320 let message = thread.last_message().unwrap();
2321 let agent_message = message.as_agent_message().unwrap();
2322
2323 let tool_use = agent_message
2324 .content
2325 .iter()
2326 .find_map(|content| match content {
2327 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2328 _ => None,
2329 })
2330 .expect("expected tool use in agent message");
2331
2332 let tool_result = agent_message
2333 .tool_results
2334 .get(&tool_use.id)
2335 .expect("expected tool result");
2336
2337 let result_text = match &tool_result.content {
2338 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2339 _ => panic!("expected text content in tool result"),
2340 };
2341
2342 assert!(
2343 result_text.contains("timed out"),
2344 "expected tool result to indicate timeout, got: {result_text}"
2345 );
2346 assert!(
2347 !result_text.contains("The user stopped"),
2348 "tool result should not mention user stopped when it timed out, got: {result_text}"
2349 );
2350 });
2351}
2352
2353#[gpui::test]
2354async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2355 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2356 let fake_model = model.as_fake();
2357
2358 let events_1 = thread
2359 .update(cx, |thread, cx| {
2360 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2361 })
2362 .unwrap();
2363 cx.run_until_parked();
2364 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2365 cx.run_until_parked();
2366
2367 let events_2 = thread
2368 .update(cx, |thread, cx| {
2369 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2370 })
2371 .unwrap();
2372 cx.run_until_parked();
2373 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2374 fake_model
2375 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2376 fake_model.end_last_completion_stream();
2377
2378 let events_1 = events_1.collect::<Vec<_>>().await;
2379 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2380 let events_2 = events_2.collect::<Vec<_>>().await;
2381 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2382}
2383
2384#[gpui::test]
2385async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2386 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2387 let fake_model = model.as_fake();
2388
2389 let events_1 = thread
2390 .update(cx, |thread, cx| {
2391 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2392 })
2393 .unwrap();
2394 cx.run_until_parked();
2395 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2396 fake_model
2397 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2398 fake_model.end_last_completion_stream();
2399 let events_1 = events_1.collect::<Vec<_>>().await;
2400
2401 let events_2 = thread
2402 .update(cx, |thread, cx| {
2403 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2404 })
2405 .unwrap();
2406 cx.run_until_parked();
2407 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2408 fake_model
2409 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2410 fake_model.end_last_completion_stream();
2411 let events_2 = events_2.collect::<Vec<_>>().await;
2412
2413 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2414 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2415}
2416
2417#[gpui::test]
2418async fn test_refusal(cx: &mut TestAppContext) {
2419 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2420 let fake_model = model.as_fake();
2421
2422 let events = thread
2423 .update(cx, |thread, cx| {
2424 thread.send(UserMessageId::new(), ["Hello"], cx)
2425 })
2426 .unwrap();
2427 cx.run_until_parked();
2428 thread.read_with(cx, |thread, _| {
2429 assert_eq!(
2430 thread.to_markdown(),
2431 indoc! {"
2432 ## User
2433
2434 Hello
2435 "}
2436 );
2437 });
2438
2439 fake_model.send_last_completion_stream_text_chunk("Hey!");
2440 cx.run_until_parked();
2441 thread.read_with(cx, |thread, _| {
2442 assert_eq!(
2443 thread.to_markdown(),
2444 indoc! {"
2445 ## User
2446
2447 Hello
2448
2449 ## Assistant
2450
2451 Hey!
2452 "}
2453 );
2454 });
2455
2456 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2457 fake_model
2458 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2459 let events = events.collect::<Vec<_>>().await;
2460 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2461 thread.read_with(cx, |thread, _| {
2462 assert_eq!(thread.to_markdown(), "");
2463 });
2464}
2465
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first user message empties the thread and clears the
    // reported token usage; the thread must still accept and complete a new send
    // afterwards, reporting usage for that new exchange only.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so the same message can be targeted by `truncate` later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model answers there is only the user message and no usage yet.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update; used_tokens is input + output.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncating at the first message removes it and everything after it.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // Checked before running the executor: the user message is already recorded.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // Only the post-truncation exchange and its usage remain.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2585
2586#[gpui::test]
2587async fn test_truncate_second_message(cx: &mut TestAppContext) {
2588 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2589 let fake_model = model.as_fake();
2590
2591 thread
2592 .update(cx, |thread, cx| {
2593 thread.send(UserMessageId::new(), ["Message 1"], cx)
2594 })
2595 .unwrap();
2596 cx.run_until_parked();
2597 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2598 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2599 language_model::TokenUsage {
2600 input_tokens: 32_000,
2601 output_tokens: 16_000,
2602 cache_creation_input_tokens: 0,
2603 cache_read_input_tokens: 0,
2604 },
2605 ));
2606 fake_model.end_last_completion_stream();
2607 cx.run_until_parked();
2608
2609 let assert_first_message_state = |cx: &mut TestAppContext| {
2610 thread.clone().read_with(cx, |thread, _| {
2611 assert_eq!(
2612 thread.to_markdown(),
2613 indoc! {"
2614 ## User
2615
2616 Message 1
2617
2618 ## Assistant
2619
2620 Message 1 response
2621 "}
2622 );
2623
2624 assert_eq!(
2625 thread.latest_token_usage(),
2626 Some(acp_thread::TokenUsage {
2627 used_tokens: 32_000 + 16_000,
2628 max_tokens: 1_000_000,
2629 input_tokens: 32_000,
2630 output_tokens: 16_000,
2631 })
2632 );
2633 });
2634 };
2635
2636 assert_first_message_state(cx);
2637
2638 let second_message_id = UserMessageId::new();
2639 thread
2640 .update(cx, |thread, cx| {
2641 thread.send(second_message_id.clone(), ["Message 2"], cx)
2642 })
2643 .unwrap();
2644 cx.run_until_parked();
2645
2646 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2647 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2648 language_model::TokenUsage {
2649 input_tokens: 40_000,
2650 output_tokens: 20_000,
2651 cache_creation_input_tokens: 0,
2652 cache_read_input_tokens: 0,
2653 },
2654 ));
2655 fake_model.end_last_completion_stream();
2656 cx.run_until_parked();
2657
2658 thread.read_with(cx, |thread, _| {
2659 assert_eq!(
2660 thread.to_markdown(),
2661 indoc! {"
2662 ## User
2663
2664 Message 1
2665
2666 ## Assistant
2667
2668 Message 1 response
2669
2670 ## User
2671
2672 Message 2
2673
2674 ## Assistant
2675
2676 Message 2 response
2677 "}
2678 );
2679
2680 assert_eq!(
2681 thread.latest_token_usage(),
2682 Some(acp_thread::TokenUsage {
2683 used_tokens: 40_000 + 20_000,
2684 max_tokens: 1_000_000,
2685 input_tokens: 40_000,
2686 output_tokens: 20_000,
2687 })
2688 );
2689 });
2690
2691 thread
2692 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2693 .unwrap();
2694 cx.run_until_parked();
2695
2696 assert_first_message_state(cx);
2697}
2698
2699#[gpui::test]
2700async fn test_title_generation(cx: &mut TestAppContext) {
2701 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2702 let fake_model = model.as_fake();
2703
2704 let summary_model = Arc::new(FakeLanguageModel::default());
2705 thread.update(cx, |thread, cx| {
2706 thread.set_summarization_model(Some(summary_model.clone()), cx)
2707 });
2708
2709 let send = thread
2710 .update(cx, |thread, cx| {
2711 thread.send(UserMessageId::new(), ["Hello"], cx)
2712 })
2713 .unwrap();
2714 cx.run_until_parked();
2715
2716 fake_model.send_last_completion_stream_text_chunk("Hey!");
2717 fake_model.end_last_completion_stream();
2718 cx.run_until_parked();
2719 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2720
2721 // Ensure the summary model has been invoked to generate a title.
2722 summary_model.send_last_completion_stream_text_chunk("Hello ");
2723 summary_model.send_last_completion_stream_text_chunk("world\nG");
2724 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2725 summary_model.end_last_completion_stream();
2726 send.collect::<Vec<_>>().await;
2727 cx.run_until_parked();
2728 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2729
2730 // Send another message, ensuring no title is generated this time.
2731 let send = thread
2732 .update(cx, |thread, cx| {
2733 thread.send(UserMessageId::new(), ["Hello again"], cx)
2734 })
2735 .unwrap();
2736 cx.run_until_parked();
2737 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2738 fake_model.end_last_completion_stream();
2739 cx.run_until_parked();
2740 assert_eq!(summary_model.pending_completions(), Vec::new());
2741 send.collect::<Vec<_>>().await;
2742 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2743}
2744
2745#[gpui::test]
2746async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2747 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2748 let fake_model = model.as_fake();
2749
2750 let _events = thread
2751 .update(cx, |thread, cx| {
2752 thread.add_tool(ToolRequiringPermission, None);
2753 thread.add_tool(EchoTool, None);
2754 thread.send(UserMessageId::new(), ["Hey!"], cx)
2755 })
2756 .unwrap();
2757 cx.run_until_parked();
2758
2759 let permission_tool_use = LanguageModelToolUse {
2760 id: "tool_id_1".into(),
2761 name: ToolRequiringPermission::NAME.into(),
2762 raw_input: "{}".into(),
2763 input: json!({}),
2764 is_input_complete: true,
2765 thought_signature: None,
2766 };
2767 let echo_tool_use = LanguageModelToolUse {
2768 id: "tool_id_2".into(),
2769 name: EchoTool::NAME.into(),
2770 raw_input: json!({"text": "test"}).to_string(),
2771 input: json!({"text": "test"}),
2772 is_input_complete: true,
2773 thought_signature: None,
2774 };
2775 fake_model.send_last_completion_stream_text_chunk("Hi!");
2776 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2777 permission_tool_use,
2778 ));
2779 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2780 echo_tool_use.clone(),
2781 ));
2782 fake_model.end_last_completion_stream();
2783 cx.run_until_parked();
2784
2785 // Ensure pending tools are skipped when building a request.
2786 let request = thread
2787 .read_with(cx, |thread, cx| {
2788 thread.build_completion_request(CompletionIntent::EditFile, cx)
2789 })
2790 .unwrap();
2791 assert_eq!(
2792 request.messages[1..],
2793 vec![
2794 LanguageModelRequestMessage {
2795 role: Role::User,
2796 content: vec!["Hey!".into()],
2797 cache: true,
2798 reasoning_details: None,
2799 },
2800 LanguageModelRequestMessage {
2801 role: Role::Assistant,
2802 content: vec![
2803 MessageContent::Text("Hi!".into()),
2804 MessageContent::ToolUse(echo_tool_use.clone())
2805 ],
2806 cache: false,
2807 reasoning_details: None,
2808 },
2809 LanguageModelRequestMessage {
2810 role: Role::User,
2811 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2812 tool_use_id: echo_tool_use.id.clone(),
2813 tool_name: echo_tool_use.name,
2814 is_error: false,
2815 content: "test".into(),
2816 output: Some("test".into())
2817 })],
2818 cache: false,
2819 reasoning_details: None,
2820 },
2821 ],
2822 );
2823}
2824
2825#[gpui::test]
2826async fn test_agent_connection(cx: &mut TestAppContext) {
2827 cx.update(settings::init);
2828 let templates = Templates::new();
2829
2830 // Initialize language model system with test provider
2831 cx.update(|cx| {
2832 gpui_tokio::init(cx);
2833
2834 let http_client = FakeHttpClient::with_404_response();
2835 let clock = Arc::new(clock::FakeSystemClock::new());
2836 let client = Client::new(clock, http_client, cx);
2837 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2838 language_model::init(client.clone(), cx);
2839 language_models::init(user_store, client.clone(), cx);
2840 LanguageModelRegistry::test(cx);
2841 });
2842 cx.executor().forbid_parking();
2843
2844 // Create a project for new_thread
2845 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2846 fake_fs.insert_tree(path!("/test"), json!({})).await;
2847 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2848 let cwd = Path::new("/test");
2849 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2850
2851 // Create agent and connection
2852 let agent = NativeAgent::new(
2853 project.clone(),
2854 thread_store,
2855 templates.clone(),
2856 None,
2857 fake_fs.clone(),
2858 &mut cx.to_async(),
2859 )
2860 .await
2861 .unwrap();
2862 let connection = NativeAgentConnection(agent.clone());
2863
2864 // Create a thread using new_thread
2865 let connection_rc = Rc::new(connection.clone());
2866 let acp_thread = cx
2867 .update(|cx| connection_rc.new_session(project, cwd, cx))
2868 .await
2869 .expect("new_thread should succeed");
2870
2871 // Get the session_id from the AcpThread
2872 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2873
2874 // Test model_selector returns Some
2875 let selector_opt = connection.model_selector(&session_id);
2876 assert!(
2877 selector_opt.is_some(),
2878 "agent should always support ModelSelector"
2879 );
2880 let selector = selector_opt.unwrap();
2881
2882 // Test list_models
2883 let listed_models = cx
2884 .update(|cx| selector.list_models(cx))
2885 .await
2886 .expect("list_models should succeed");
2887 let AgentModelList::Grouped(listed_models) = listed_models else {
2888 panic!("Unexpected model list type");
2889 };
2890 assert!(!listed_models.is_empty(), "should have at least one model");
2891 assert_eq!(
2892 listed_models[&AgentModelGroupName("Fake".into())][0]
2893 .id
2894 .0
2895 .as_ref(),
2896 "fake/fake"
2897 );
2898
2899 // Test selected_model returns the default
2900 let model = cx
2901 .update(|cx| selector.selected_model(cx))
2902 .await
2903 .expect("selected_model should succeed");
2904 let model = cx
2905 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2906 .unwrap();
2907 let model = model.as_fake();
2908 assert_eq!(model.id().0, "fake", "should return default model");
2909
2910 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2911 cx.run_until_parked();
2912 model.send_last_completion_stream_text_chunk("def");
2913 cx.run_until_parked();
2914 acp_thread.read_with(cx, |thread, cx| {
2915 assert_eq!(
2916 thread.to_markdown(cx),
2917 indoc! {"
2918 ## User
2919
2920 abc
2921
2922 ## Assistant
2923
2924 def
2925
2926 "}
2927 )
2928 });
2929
2930 // Test cancel
2931 cx.update(|cx| connection.cancel(&session_id, cx));
2932 request.await.expect("prompt should fail gracefully");
2933
2934 // Explicitly close the session and drop the ACP thread.
2935 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
2936 .await
2937 .unwrap();
2938 drop(acp_thread);
2939 let result = cx
2940 .update(|cx| {
2941 connection.prompt(
2942 Some(acp_thread::UserMessageId::new()),
2943 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2944 cx,
2945 )
2946 })
2947 .await;
2948 assert_eq!(
2949 result.as_ref().unwrap_err().to_string(),
2950 "Session not found",
2951 "unexpected result: {:?}",
2952 result
2953 );
2954}
2955
2956#[gpui::test]
2957async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2958 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2959 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
2960 let fake_model = model.as_fake();
2961
2962 let mut events = thread
2963 .update(cx, |thread, cx| {
2964 thread.send(UserMessageId::new(), ["Echo something"], cx)
2965 })
2966 .unwrap();
2967 cx.run_until_parked();
2968
2969 // Simulate streaming partial input.
2970 let input = json!({});
2971 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2972 LanguageModelToolUse {
2973 id: "1".into(),
2974 name: EchoTool::NAME.into(),
2975 raw_input: input.to_string(),
2976 input,
2977 is_input_complete: false,
2978 thought_signature: None,
2979 },
2980 ));
2981
2982 // Input streaming completed
2983 let input = json!({ "text": "Hello!" });
2984 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2985 LanguageModelToolUse {
2986 id: "1".into(),
2987 name: "echo".into(),
2988 raw_input: input.to_string(),
2989 input,
2990 is_input_complete: true,
2991 thought_signature: None,
2992 },
2993 ));
2994 fake_model.end_last_completion_stream();
2995 cx.run_until_parked();
2996
2997 let tool_call = expect_tool_call(&mut events).await;
2998 assert_eq!(
2999 tool_call,
3000 acp::ToolCall::new("1", "Echo")
3001 .raw_input(json!({}))
3002 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3003 );
3004 let update = expect_tool_call_update_fields(&mut events).await;
3005 assert_eq!(
3006 update,
3007 acp::ToolCallUpdate::new(
3008 "1",
3009 acp::ToolCallUpdateFields::new()
3010 .title("Echo")
3011 .kind(acp::ToolKind::Other)
3012 .raw_input(json!({ "text": "Hello!"}))
3013 )
3014 );
3015 let update = expect_tool_call_update_fields(&mut events).await;
3016 assert_eq!(
3017 update,
3018 acp::ToolCallUpdate::new(
3019 "1",
3020 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3021 )
3022 );
3023 let update = expect_tool_call_update_fields(&mut events).await;
3024 assert_eq!(
3025 update,
3026 acp::ToolCallUpdate::new(
3027 "1",
3028 acp::ToolCallUpdateFields::new()
3029 .status(acp::ToolCallStatus::Completed)
3030 .raw_output("Hello!")
3031 )
3032 );
3033}
3034
3035#[gpui::test]
3036async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3037 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3038 let fake_model = model.as_fake();
3039
3040 let mut events = thread
3041 .update(cx, |thread, cx| {
3042 thread.send(UserMessageId::new(), ["Hello!"], cx)
3043 })
3044 .unwrap();
3045 cx.run_until_parked();
3046
3047 fake_model.send_last_completion_stream_text_chunk("Hey!");
3048 fake_model.end_last_completion_stream();
3049
3050 let mut retry_events = Vec::new();
3051 while let Some(Ok(event)) = events.next().await {
3052 match event {
3053 ThreadEvent::Retry(retry_status) => {
3054 retry_events.push(retry_status);
3055 }
3056 ThreadEvent::Stop(..) => break,
3057 _ => {}
3058 }
3059 }
3060
3061 assert_eq!(retry_events.len(), 0);
3062 thread.read_with(cx, |thread, _cx| {
3063 assert_eq!(
3064 thread.to_markdown(),
3065 indoc! {"
3066 ## User
3067
3068 Hello!
3069
3070 ## Assistant
3071
3072 Hey!
3073 "}
3074 )
3075 });
3076}
3077
3078#[gpui::test]
3079async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3080 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3081 let fake_model = model.as_fake();
3082
3083 let mut events = thread
3084 .update(cx, |thread, cx| {
3085 thread.send(UserMessageId::new(), ["Hello!"], cx)
3086 })
3087 .unwrap();
3088 cx.run_until_parked();
3089
3090 fake_model.send_last_completion_stream_text_chunk("Hey,");
3091 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3092 provider: LanguageModelProviderName::new("Anthropic"),
3093 retry_after: Some(Duration::from_secs(3)),
3094 });
3095 fake_model.end_last_completion_stream();
3096
3097 cx.executor().advance_clock(Duration::from_secs(3));
3098 cx.run_until_parked();
3099
3100 fake_model.send_last_completion_stream_text_chunk("there!");
3101 fake_model.end_last_completion_stream();
3102 cx.run_until_parked();
3103
3104 let mut retry_events = Vec::new();
3105 while let Some(Ok(event)) = events.next().await {
3106 match event {
3107 ThreadEvent::Retry(retry_status) => {
3108 retry_events.push(retry_status);
3109 }
3110 ThreadEvent::Stop(..) => break,
3111 _ => {}
3112 }
3113 }
3114
3115 assert_eq!(retry_events.len(), 1);
3116 assert!(matches!(
3117 retry_events[0],
3118 acp_thread::RetryStatus { attempt: 1, .. }
3119 ));
3120 thread.read_with(cx, |thread, _cx| {
3121 assert_eq!(
3122 thread.to_markdown(),
3123 indoc! {"
3124 ## User
3125
3126 Hello!
3127
3128 ## Assistant
3129
3130 Hey,
3131
3132 [resume]
3133
3134 ## Assistant
3135
3136 there!
3137 "}
3138 )
3139 });
3140}
3141
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // When a retryable error arrives after the model has emitted a tool use, the
    // tool call is still run to completion. The retried request then carries both
    // the tool use and its result, and the retried turn's output lands in a fresh
    // agent message.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Emit the tool use, then fail the stream with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Wait out the retry delay, then inspect the retried request. Messages after
    // the first include the echo tool use and its completed result; the final
    // message (the tool result) is the one marked for caching.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried turn; its text becomes the thread's last agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3221
3222#[gpui::test]
3223async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3224 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3225 let fake_model = model.as_fake();
3226
3227 let mut events = thread
3228 .update(cx, |thread, cx| {
3229 thread.send(UserMessageId::new(), ["Hello!"], cx)
3230 })
3231 .unwrap();
3232 cx.run_until_parked();
3233
3234 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3235 fake_model.send_last_completion_stream_error(
3236 LanguageModelCompletionError::ServerOverloaded {
3237 provider: LanguageModelProviderName::new("Anthropic"),
3238 retry_after: Some(Duration::from_secs(3)),
3239 },
3240 );
3241 fake_model.end_last_completion_stream();
3242 cx.executor().advance_clock(Duration::from_secs(3));
3243 cx.run_until_parked();
3244 }
3245
3246 let mut errors = Vec::new();
3247 let mut retry_events = Vec::new();
3248 while let Some(event) = events.next().await {
3249 match event {
3250 Ok(ThreadEvent::Retry(retry_status)) => {
3251 retry_events.push(retry_status);
3252 }
3253 Ok(ThreadEvent::Stop(..)) => break,
3254 Err(error) => errors.push(error),
3255 _ => {}
3256 }
3257 }
3258
3259 assert_eq!(
3260 retry_events.len(),
3261 crate::thread::MAX_RETRY_ATTEMPTS as usize
3262 );
3263 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3264 assert_eq!(retry_events[i].attempt, i + 1);
3265 }
3266 assert_eq!(errors.len(), 1);
3267 let error = errors[0]
3268 .downcast_ref::<LanguageModelCompletionError>()
3269 .unwrap();
3270 assert!(matches!(
3271 error,
3272 LanguageModelCompletionError::ServerOverloaded { .. }
3273 ));
3274}
3275
3276/// Filters out the stop events for asserting against in tests
3277fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3278 result_events
3279 .into_iter()
3280 .filter_map(|event| match event.unwrap() {
3281 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3282 _ => None,
3283 })
3284 .collect()
3285}
3286
/// Handles returned by [`setup`] for driving a [`Thread`] in tests.
struct ThreadTest {
    /// The model the thread was created with: either a fake model or a real provider
    /// model, depending on the [`TestModel`] passed to [`setup`].
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity that was handed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    /// The project's context server store, e.g. for registering fake MCP servers
    /// through [`setup_context_server`].
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem backing both the test project and the user settings file.
    fs: Arc<FakeFs>,
}
3294
/// Selects which language model [`setup`] creates the thread with.
enum TestModel {
    /// Claude Sonnet 4, looked up in the real language model registry. [`setup`] wires a
    /// real HTTP client and authenticates the provider for this variant.
    Sonnet4,
    /// An in-process [`FakeLanguageModel`] whose responses the test feeds manually.
    Fake,
}

impl TestModel {
    /// Returns the registry id of the real model this variant refers to.
    ///
    /// # Panics
    ///
    /// Panics for [`TestModel::Fake`], which has no registry id. [`setup`] constructs
    /// the fake model directly instead of looking it up.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
3308
/// Builds a [`Thread`] backed by a fake filesystem and an empty `/test` project.
///
/// The user settings file is seeded with a `test-profile` agent profile (set as the
/// default profile) that enables all of the test tools plus the terminal tool, and
/// [`watch_settings`] keeps the global settings in sync with that file.
///
/// For [`TestModel::Fake`] the thread uses a fresh [`FakeLanguageModel`]. For
/// [`TestModel::Sonnet4`] it initializes a real HTTP client, the production
/// [`Client`] and the language model providers, then looks the model up by id and
/// authenticates its provider. That variant therefore needs network access and
/// credentials, and it panics (via `unwrap`) if either step fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Tests built on this helper may block on real I/O (e.g. the Sonnet4 path).
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real model: needs a real HTTP stack and registered providers.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Only hand the model out once its provider has authenticated.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3411
// Runs when the test binary is loaded (via `ctor`) and enables `env_logger` only if
// `RUST_LOG` is set to a valid Unicode value, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}
3419
3420fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3421 let fs = fs.clone();
3422 cx.spawn({
3423 async move |cx| {
3424 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3425 cx.background_executor(),
3426 fs,
3427 paths::settings_file().clone(),
3428 );
3429 let _watcher_task = watcher_task;
3430
3431 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3432 cx.update(|cx| {
3433 SettingsStore::update_global(cx, |settings, cx| {
3434 settings.set_user_settings(&new_settings_content, cx)
3435 })
3436 })
3437 .ok();
3438 }
3439 }
3440 })
3441 .detach();
3442}
3443
3444fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3445 completion
3446 .tools
3447 .iter()
3448 .map(|tool| tool.name.clone())
3449 .collect()
3450}
3451
/// Registers and starts a fake stdio MCP context server called `name` that exposes
/// `tools`.
///
/// The server is added to the global [`ProjectSettings`] (with a placeholder
/// `somebinary` command) and started on `context_server_store` over a fake transport.
/// The transport answers:
/// - `Initialize` with the server name, version `1.0.0`, and a tools capability whose
///   `list_changed` is `Some(true)`;
/// - `ListTools` with `tools` and no pagination cursor;
/// - `CallTool` by forwarding the call parameters and a response sender to the
///   returned receiver, then waiting for the test to reply.
///
/// The returned receiver yields one `(params, responder)` pair per tool call. The test
/// must answer each call through its oneshot sender. If the sender is dropped
/// instead, the fake handler panics (it `unwrap`s the oneshot result).
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until it responds.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish starting before the caller uses it.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3531
// Verifies `Thread::tokens_before_message`. The first user message reports `None`.
// Every later message reports the `input_tokens` from the usage update of the request
// made for the user message immediately before it, and that value stays stable as
// the conversation grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3638
// Verifies that after `Thread::truncate` removes a user message, looking up that
// message's id through `tokens_before_message` returns `None`, and that the surviving
// first message still reports `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // NOTE(review): only two messages are actually sent below.
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3709
3710#[gpui::test]
3711async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3712 init_test(cx);
3713
3714 let fs = FakeFs::new(cx.executor());
3715 fs.insert_tree("/root", json!({})).await;
3716 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3717
3718 // Test 1: Deny rule blocks command
3719 {
3720 let environment = Rc::new(cx.update(|cx| {
3721 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3722 }));
3723
3724 cx.update(|cx| {
3725 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3726 settings.tool_permissions.tools.insert(
3727 TerminalTool::NAME.into(),
3728 agent_settings::ToolRules {
3729 default: Some(settings::ToolPermissionMode::Confirm),
3730 always_allow: vec![],
3731 always_deny: vec![
3732 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3733 ],
3734 always_confirm: vec![],
3735 invalid_patterns: vec![],
3736 },
3737 );
3738 agent_settings::AgentSettings::override_global(settings, cx);
3739 });
3740
3741 #[allow(clippy::arc_with_non_send_sync)]
3742 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3743 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3744
3745 let task = cx.update(|cx| {
3746 tool.run(
3747 crate::TerminalToolInput {
3748 command: "rm -rf /".to_string(),
3749 cd: ".".to_string(),
3750 timeout_ms: None,
3751 },
3752 event_stream,
3753 cx,
3754 )
3755 });
3756
3757 let result = task.await;
3758 assert!(
3759 result.is_err(),
3760 "expected command to be blocked by deny rule"
3761 );
3762 let err_msg = result.unwrap_err().to_string().to_lowercase();
3763 assert!(
3764 err_msg.contains("blocked"),
3765 "error should mention the command was blocked"
3766 );
3767 }
3768
3769 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
3770 {
3771 let environment = Rc::new(cx.update(|cx| {
3772 FakeThreadEnvironment::default()
3773 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
3774 }));
3775
3776 cx.update(|cx| {
3777 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3778 settings.tool_permissions.tools.insert(
3779 TerminalTool::NAME.into(),
3780 agent_settings::ToolRules {
3781 default: Some(settings::ToolPermissionMode::Deny),
3782 always_allow: vec![
3783 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3784 ],
3785 always_deny: vec![],
3786 always_confirm: vec![],
3787 invalid_patterns: vec![],
3788 },
3789 );
3790 agent_settings::AgentSettings::override_global(settings, cx);
3791 });
3792
3793 #[allow(clippy::arc_with_non_send_sync)]
3794 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3795 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3796
3797 let task = cx.update(|cx| {
3798 tool.run(
3799 crate::TerminalToolInput {
3800 command: "echo hello".to_string(),
3801 cd: ".".to_string(),
3802 timeout_ms: None,
3803 },
3804 event_stream,
3805 cx,
3806 )
3807 });
3808
3809 let update = rx.expect_update_fields().await;
3810 assert!(
3811 update.content.iter().any(|blocks| {
3812 blocks
3813 .iter()
3814 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3815 }),
3816 "expected terminal content (allow rule should skip confirmation and override default deny)"
3817 );
3818
3819 let result = task.await;
3820 assert!(
3821 result.is_ok(),
3822 "expected command to succeed without confirmation"
3823 );
3824 }
3825
3826 // Test 3: global default: allow does NOT override always_confirm patterns
3827 {
3828 let environment = Rc::new(cx.update(|cx| {
3829 FakeThreadEnvironment::default()
3830 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
3831 }));
3832
3833 cx.update(|cx| {
3834 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3835 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
3836 settings.tool_permissions.tools.insert(
3837 TerminalTool::NAME.into(),
3838 agent_settings::ToolRules {
3839 default: Some(settings::ToolPermissionMode::Allow),
3840 always_allow: vec![],
3841 always_deny: vec![],
3842 always_confirm: vec![
3843 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3844 ],
3845 invalid_patterns: vec![],
3846 },
3847 );
3848 agent_settings::AgentSettings::override_global(settings, cx);
3849 });
3850
3851 #[allow(clippy::arc_with_non_send_sync)]
3852 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3853 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3854
3855 let _task = cx.update(|cx| {
3856 tool.run(
3857 crate::TerminalToolInput {
3858 command: "sudo rm file".to_string(),
3859 cd: ".".to_string(),
3860 timeout_ms: None,
3861 },
3862 event_stream,
3863 cx,
3864 )
3865 });
3866
3867 // With global default: allow, confirm patterns are still respected
3868 // The expect_authorization() call will panic if no authorization is requested,
3869 // which validates that the confirm pattern still triggers confirmation
3870 let _auth = rx.expect_authorization().await;
3871
3872 drop(_task);
3873 }
3874
3875 // Test 4: tool-specific default: deny is respected even with global default: allow
3876 {
3877 let environment = Rc::new(cx.update(|cx| {
3878 FakeThreadEnvironment::default()
3879 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
3880 }));
3881
3882 cx.update(|cx| {
3883 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3884 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
3885 settings.tool_permissions.tools.insert(
3886 TerminalTool::NAME.into(),
3887 agent_settings::ToolRules {
3888 default: Some(settings::ToolPermissionMode::Deny),
3889 always_allow: vec![],
3890 always_deny: vec![],
3891 always_confirm: vec![],
3892 invalid_patterns: vec![],
3893 },
3894 );
3895 agent_settings::AgentSettings::override_global(settings, cx);
3896 });
3897
3898 #[allow(clippy::arc_with_non_send_sync)]
3899 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3900 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3901
3902 let task = cx.update(|cx| {
3903 tool.run(
3904 crate::TerminalToolInput {
3905 command: "echo hello".to_string(),
3906 cd: ".".to_string(),
3907 timeout_ms: None,
3908 },
3909 event_stream,
3910 cx,
3911 )
3912 });
3913
3914 // tool-specific default: deny is respected even with global default: allow
3915 let result = task.await;
3916 assert!(
3917 result.is_err(),
3918 "expected command to be blocked by tool-specific deny default"
3919 );
3920 let err_msg = result.unwrap_err().to_string().to_lowercase();
3921 assert!(
3922 err_msg.contains("disabled"),
3923 "error should mention the tool is disabled, got: {err_msg}"
3924 );
3925 }
3926}
3927
// Verifies that with the `subagents` feature flag enabled, `add_default_tools`
// registers `SubagentTool` on a top-level thread.
#[gpui::test]
async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));

    let thread = cx.new(|cx| {
        let mut thread = Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model),
            cx,
        );
        thread.add_default_tools(None, environment, cx);
        thread
    });

    thread.read_with(cx, |thread, _| {
        assert!(
            thread.has_registered_tool(SubagentTool::NAME),
            "subagent tool should be present when feature flag is enabled"
        );
    });
}
3969
// Verifies that `Thread::new_subagent` produces a thread that is marked as a
// subagent, sits at depth 1 below a top-level parent, uses the parent's model, and
// records the parent's id as its `parent_thread_id`.
#[gpui::test]
async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent_thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
    subagent_thread.read_with(cx, |subagent_thread, cx| {
        assert!(subagent_thread.is_subagent());
        assert_eq!(subagent_thread.depth(), 1);
        assert_eq!(
            subagent_thread.model().map(|model| model.id()),
            Some(model.id())
        );
        assert_eq!(
            subagent_thread.parent_thread_id(),
            Some(parent_thread.read(cx).id().clone())
        );
    });
}
4012
// Verifies that a subagent created at `MAX_SUBAGENT_DEPTH` does not get
// `SubagentTool` from `add_default_tools`, so nesting cannot go deeper than the limit.
#[gpui::test]
async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));

    // Pretend the parent is already one level below the limit, so its child lands
    // exactly on `MAX_SUBAGENT_DEPTH`.
    let deep_parent_thread = cx.new(|cx| {
        let mut thread = Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        );
        thread.set_subagent_context(SubagentContext {
            parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
            depth: MAX_SUBAGENT_DEPTH - 1,
        });
        thread
    });
    let deep_subagent_thread = cx.new(|cx| {
        let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
        thread.add_default_tools(None, environment, cx);
        thread
    });

    deep_subagent_thread.read_with(cx, |thread, _| {
        assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
        assert!(
            !thread.has_registered_tool(SubagentTool::NAME),
            "subagent tool should not be present at max depth"
        );
    });
}
4062
// Verifies that cancelling a parent thread also ends the turn of a subagent
// registered with it via `register_running_subagent`.
#[gpui::test]
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));

    parent.update(cx, |thread, _cx| {
        thread.register_running_subagent(subagent.downgrade());
    });

    // Start a subagent turn that stays open: the fake model never finishes the stream.
    subagent
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
        })
        .unwrap();
    cx.run_until_parked();

    subagent.read_with(cx, |thread, _| {
        assert!(!thread.is_turn_complete(), "subagent should be running");
    });

    parent.update(cx, |thread, cx| {
        thread.cancel(cx).detach();
    });

    // Checked right after `cancel`, without `run_until_parked`, so this expects the
    // subagent's turn to be ended synchronously by the parent's cancellation.
    subagent.read_with(cx, |thread, _| {
        assert!(
            thread.is_turn_complete(),
            "subagent should be cancelled when parent cancels"
        );
    });
}
4119
// Verifies that `SubagentTool::run` returns a "cancelled by user" error promptly
// (within a 5 second timer) once the tool call's event stream signals cancellation,
// even while the subagent it started never completes.
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()` and stops all running subagents.
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    // The fake environment hands out a subagent that never finishes, so only the
    // cancellation signal can end the tool run.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
    }));

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));

    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                label: "Long running task".to_string(),
                task_prompt: "Do a very long task that takes forever".to_string(),
                summary_prompt: "Summarize".to_string(),
                timeout_ms: None,
                allowed_tools: None,
            },
            event_stream.clone(),
            cx,
        )
    });

    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4197
4198#[gpui::test]
4199async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4200 init_test(cx);
4201 always_allow_tools(cx);
4202
4203 cx.update(|cx| {
4204 cx.update_flags(true, vec!["subagents".to_string()]);
4205 });
4206
4207 let fs = FakeFs::new(cx.executor());
4208 fs.insert_tree(path!("/test"), json!({})).await;
4209 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4210 let project_context = cx.new(|_cx| ProjectContext::default());
4211 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4212 let context_server_registry =
4213 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4214 cx.update(LanguageModelRegistry::test);
4215 let model = Arc::new(FakeLanguageModel::default());
4216 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4217 let native_agent = NativeAgent::new(
4218 project.clone(),
4219 thread_store,
4220 Templates::new(),
4221 None,
4222 fs,
4223 &mut cx.to_async(),
4224 )
4225 .await
4226 .unwrap();
4227 let parent_thread = cx.new(|cx| {
4228 Thread::new(
4229 project.clone(),
4230 project_context,
4231 context_server_registry,
4232 Templates::new(),
4233 Some(model.clone()),
4234 cx,
4235 )
4236 });
4237
4238 let mut handles = Vec::new();
4239 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4240 let handle = cx
4241 .update(|cx| {
4242 NativeThreadEnvironment::create_subagent_thread(
4243 native_agent.downgrade(),
4244 parent_thread.clone(),
4245 "some title".to_string(),
4246 "some task".to_string(),
4247 None,
4248 None,
4249 cx,
4250 )
4251 })
4252 .expect("Expected to be able to create subagent thread");
4253 handles.push(handle);
4254 }
4255
4256 let result = cx.update(|cx| {
4257 NativeThreadEnvironment::create_subagent_thread(
4258 native_agent.downgrade(),
4259 parent_thread.clone(),
4260 "some title".to_string(),
4261 "some task".to_string(),
4262 None,
4263 None,
4264 cx,
4265 )
4266 });
4267 assert!(result.is_err());
4268 assert_eq!(
4269 result.err().unwrap().to_string(),
4270 format!(
4271 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4272 MAX_PARALLEL_SUBAGENTS
4273 )
4274 );
4275}
4276
4277#[gpui::test]
4278async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4279 init_test(cx);
4280
4281 always_allow_tools(cx);
4282
4283 cx.update(|cx| {
4284 cx.update_flags(true, vec!["subagents".to_string()]);
4285 });
4286
4287 let fs = FakeFs::new(cx.executor());
4288 fs.insert_tree(path!("/test"), json!({})).await;
4289 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4290 let project_context = cx.new(|_cx| ProjectContext::default());
4291 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4292 let context_server_registry =
4293 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4294 cx.update(LanguageModelRegistry::test);
4295 let model = Arc::new(FakeLanguageModel::default());
4296 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4297 let native_agent = NativeAgent::new(
4298 project.clone(),
4299 thread_store,
4300 Templates::new(),
4301 None,
4302 fs,
4303 &mut cx.to_async(),
4304 )
4305 .await
4306 .unwrap();
4307 let parent_thread = cx.new(|cx| {
4308 Thread::new(
4309 project.clone(),
4310 project_context,
4311 context_server_registry,
4312 Templates::new(),
4313 Some(model.clone()),
4314 cx,
4315 )
4316 });
4317
4318 let subagent_handle = cx
4319 .update(|cx| {
4320 NativeThreadEnvironment::create_subagent_thread(
4321 native_agent.downgrade(),
4322 parent_thread.clone(),
4323 "some title".to_string(),
4324 "task prompt".to_string(),
4325 Some(Duration::from_millis(10)),
4326 None,
4327 cx,
4328 )
4329 })
4330 .expect("Failed to create subagent");
4331
4332 let summary_task =
4333 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4334
4335 cx.run_until_parked();
4336
4337 {
4338 let messages = model.pending_completions().last().unwrap().messages.clone();
4339 // Ensure that model received a system prompt
4340 assert_eq!(messages[0].role, Role::System);
4341 // Ensure that model received a task prompt
4342 assert_eq!(messages[1].role, Role::User);
4343 assert_eq!(
4344 messages[1].content,
4345 vec![MessageContent::Text("task prompt".to_string())]
4346 );
4347 }
4348
4349 model.send_last_completion_stream_text_chunk("Some task response...");
4350 model.end_last_completion_stream();
4351
4352 cx.run_until_parked();
4353
4354 {
4355 let messages = model.pending_completions().last().unwrap().messages.clone();
4356 assert_eq!(messages[2].role, Role::Assistant);
4357 assert_eq!(
4358 messages[2].content,
4359 vec![MessageContent::Text("Some task response...".to_string())]
4360 );
4361 // Ensure that model received a summary prompt
4362 assert_eq!(messages[3].role, Role::User);
4363 assert_eq!(
4364 messages[3].content,
4365 vec![MessageContent::Text("summary prompt".to_string())]
4366 );
4367 }
4368
4369 model.send_last_completion_stream_text_chunk("Some summary...");
4370 model.end_last_completion_stream();
4371
4372 let result = summary_task.await;
4373 assert_eq!(result.unwrap(), "Some summary...\n");
4374}
4375
4376#[gpui::test]
4377async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4378 cx: &mut TestAppContext,
4379) {
4380 init_test(cx);
4381
4382 always_allow_tools(cx);
4383
4384 cx.update(|cx| {
4385 cx.update_flags(true, vec!["subagents".to_string()]);
4386 });
4387
4388 let fs = FakeFs::new(cx.executor());
4389 fs.insert_tree(path!("/test"), json!({})).await;
4390 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4391 let project_context = cx.new(|_cx| ProjectContext::default());
4392 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4393 let context_server_registry =
4394 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4395 cx.update(LanguageModelRegistry::test);
4396 let model = Arc::new(FakeLanguageModel::default());
4397 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4398 let native_agent = NativeAgent::new(
4399 project.clone(),
4400 thread_store,
4401 Templates::new(),
4402 None,
4403 fs,
4404 &mut cx.to_async(),
4405 )
4406 .await
4407 .unwrap();
4408 let parent_thread = cx.new(|cx| {
4409 Thread::new(
4410 project.clone(),
4411 project_context,
4412 context_server_registry,
4413 Templates::new(),
4414 Some(model.clone()),
4415 cx,
4416 )
4417 });
4418
4419 let subagent_handle = cx
4420 .update(|cx| {
4421 NativeThreadEnvironment::create_subagent_thread(
4422 native_agent.downgrade(),
4423 parent_thread.clone(),
4424 "some title".to_string(),
4425 "task prompt".to_string(),
4426 Some(Duration::from_millis(100)),
4427 None,
4428 cx,
4429 )
4430 })
4431 .expect("Failed to create subagent");
4432
4433 let summary_task =
4434 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4435
4436 cx.run_until_parked();
4437
4438 {
4439 let messages = model.pending_completions().last().unwrap().messages.clone();
4440 // Ensure that model received a system prompt
4441 assert_eq!(messages[0].role, Role::System);
4442 // Ensure that model received a task prompt
4443 assert_eq!(
4444 messages[1].content,
4445 vec![MessageContent::Text("task prompt".to_string())]
4446 );
4447 }
4448
4449 // Don't complete the initial model stream — let the timeout expire instead.
4450 cx.executor().advance_clock(Duration::from_millis(200));
4451 cx.run_until_parked();
4452
4453 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4454 // instead of the summary_prompt.
4455 {
4456 let messages = model.pending_completions().last().unwrap().messages.clone();
4457 let last_user_message = messages
4458 .iter()
4459 .rev()
4460 .find(|m| m.role == Role::User)
4461 .unwrap();
4462 assert_eq!(
4463 last_user_message.content,
4464 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
4465 );
4466 }
4467
4468 model.send_last_completion_stream_text_chunk("Some context low response...");
4469 model.end_last_completion_stream();
4470
4471 let result = summary_task.await;
4472 assert_eq!(result.unwrap(), "Some context low response...\n");
4473}
4474
4475#[gpui::test]
4476async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4477 init_test(cx);
4478
4479 always_allow_tools(cx);
4480
4481 cx.update(|cx| {
4482 cx.update_flags(true, vec!["subagents".to_string()]);
4483 });
4484
4485 let fs = FakeFs::new(cx.executor());
4486 fs.insert_tree(path!("/test"), json!({})).await;
4487 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4488 let project_context = cx.new(|_cx| ProjectContext::default());
4489 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4490 let context_server_registry =
4491 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4492 cx.update(LanguageModelRegistry::test);
4493 let model = Arc::new(FakeLanguageModel::default());
4494 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4495 let native_agent = NativeAgent::new(
4496 project.clone(),
4497 thread_store,
4498 Templates::new(),
4499 None,
4500 fs,
4501 &mut cx.to_async(),
4502 )
4503 .await
4504 .unwrap();
4505 let parent_thread = cx.new(|cx| {
4506 let mut thread = Thread::new(
4507 project.clone(),
4508 project_context,
4509 context_server_registry,
4510 Templates::new(),
4511 Some(model.clone()),
4512 cx,
4513 );
4514 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4515 thread.add_tool(GrepTool::new(project.clone()), None);
4516 thread
4517 });
4518
4519 let _subagent_handle = cx
4520 .update(|cx| {
4521 NativeThreadEnvironment::create_subagent_thread(
4522 native_agent.downgrade(),
4523 parent_thread.clone(),
4524 "some title".to_string(),
4525 "task prompt".to_string(),
4526 Some(Duration::from_millis(10)),
4527 None,
4528 cx,
4529 )
4530 })
4531 .expect("Failed to create subagent");
4532
4533 cx.run_until_parked();
4534
4535 let tools = model
4536 .pending_completions()
4537 .last()
4538 .unwrap()
4539 .tools
4540 .iter()
4541 .map(|tool| tool.name.clone())
4542 .collect::<Vec<_>>();
4543 assert_eq!(tools.len(), 2);
4544 assert!(tools.contains(&"grep".to_string()));
4545 assert!(tools.contains(&"list_directory".to_string()));
4546}
4547
4548#[gpui::test]
4549async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4550 init_test(cx);
4551
4552 always_allow_tools(cx);
4553
4554 cx.update(|cx| {
4555 cx.update_flags(true, vec!["subagents".to_string()]);
4556 });
4557
4558 let fs = FakeFs::new(cx.executor());
4559 fs.insert_tree(path!("/test"), json!({})).await;
4560 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4561 let project_context = cx.new(|_cx| ProjectContext::default());
4562 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4563 let context_server_registry =
4564 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4565 cx.update(LanguageModelRegistry::test);
4566 let model = Arc::new(FakeLanguageModel::default());
4567 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4568 let native_agent = NativeAgent::new(
4569 project.clone(),
4570 thread_store,
4571 Templates::new(),
4572 None,
4573 fs,
4574 &mut cx.to_async(),
4575 )
4576 .await
4577 .unwrap();
4578 let parent_thread = cx.new(|cx| {
4579 let mut thread = Thread::new(
4580 project.clone(),
4581 project_context,
4582 context_server_registry,
4583 Templates::new(),
4584 Some(model.clone()),
4585 cx,
4586 );
4587 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4588 thread.add_tool(GrepTool::new(project.clone()), None);
4589 thread
4590 });
4591
4592 let _subagent_handle = cx
4593 .update(|cx| {
4594 NativeThreadEnvironment::create_subagent_thread(
4595 native_agent.downgrade(),
4596 parent_thread.clone(),
4597 "some title".to_string(),
4598 "task prompt".to_string(),
4599 Some(Duration::from_millis(10)),
4600 Some(vec!["grep".to_string()]),
4601 cx,
4602 )
4603 })
4604 .expect("Failed to create subagent");
4605
4606 cx.run_until_parked();
4607
4608 let tools = model
4609 .pending_completions()
4610 .last()
4611 .unwrap()
4612 .tools
4613 .iter()
4614 .map(|tool| tool.name.clone())
4615 .collect::<Vec<_>>();
4616 assert_eq!(tools, vec!["grep"]);
4617}
4618
4619#[gpui::test]
4620async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4621 init_test(cx);
4622
4623 let fs = FakeFs::new(cx.executor());
4624 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4625 .await;
4626 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4627
4628 cx.update(|cx| {
4629 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4630 settings.tool_permissions.tools.insert(
4631 EditFileTool::NAME.into(),
4632 agent_settings::ToolRules {
4633 default: Some(settings::ToolPermissionMode::Allow),
4634 always_allow: vec![],
4635 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4636 always_confirm: vec![],
4637 invalid_patterns: vec![],
4638 },
4639 );
4640 agent_settings::AgentSettings::override_global(settings, cx);
4641 });
4642
4643 let context_server_registry =
4644 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4645 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4646 let templates = crate::Templates::new();
4647 let thread = cx.new(|cx| {
4648 crate::Thread::new(
4649 project.clone(),
4650 cx.new(|_cx| prompt_store::ProjectContext::default()),
4651 context_server_registry,
4652 templates.clone(),
4653 None,
4654 cx,
4655 )
4656 });
4657
4658 #[allow(clippy::arc_with_non_send_sync)]
4659 let tool = Arc::new(crate::EditFileTool::new(
4660 project.clone(),
4661 thread.downgrade(),
4662 language_registry,
4663 templates,
4664 ));
4665 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4666
4667 let task = cx.update(|cx| {
4668 tool.run(
4669 crate::EditFileToolInput {
4670 display_description: "Edit sensitive file".to_string(),
4671 path: "root/sensitive_config.txt".into(),
4672 mode: crate::EditFileMode::Edit,
4673 },
4674 event_stream,
4675 cx,
4676 )
4677 });
4678
4679 let result = task.await;
4680 assert!(result.is_err(), "expected edit to be blocked");
4681 assert!(
4682 result.unwrap_err().to_string().contains("blocked"),
4683 "error should mention the edit was blocked"
4684 );
4685}
4686
4687#[gpui::test]
4688async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4689 init_test(cx);
4690
4691 let fs = FakeFs::new(cx.executor());
4692 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4693 .await;
4694 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4695
4696 cx.update(|cx| {
4697 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4698 settings.tool_permissions.tools.insert(
4699 DeletePathTool::NAME.into(),
4700 agent_settings::ToolRules {
4701 default: Some(settings::ToolPermissionMode::Allow),
4702 always_allow: vec![],
4703 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4704 always_confirm: vec![],
4705 invalid_patterns: vec![],
4706 },
4707 );
4708 agent_settings::AgentSettings::override_global(settings, cx);
4709 });
4710
4711 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4712
4713 #[allow(clippy::arc_with_non_send_sync)]
4714 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4715 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4716
4717 let task = cx.update(|cx| {
4718 tool.run(
4719 crate::DeletePathToolInput {
4720 path: "root/important_data.txt".to_string(),
4721 },
4722 event_stream,
4723 cx,
4724 )
4725 });
4726
4727 let result = task.await;
4728 assert!(result.is_err(), "expected deletion to be blocked");
4729 assert!(
4730 result.unwrap_err().to_string().contains("blocked"),
4731 "error should mention the deletion was blocked"
4732 );
4733}
4734
4735#[gpui::test]
4736async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4737 init_test(cx);
4738
4739 let fs = FakeFs::new(cx.executor());
4740 fs.insert_tree(
4741 "/root",
4742 json!({
4743 "safe.txt": "content",
4744 "protected": {}
4745 }),
4746 )
4747 .await;
4748 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4749
4750 cx.update(|cx| {
4751 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4752 settings.tool_permissions.tools.insert(
4753 MovePathTool::NAME.into(),
4754 agent_settings::ToolRules {
4755 default: Some(settings::ToolPermissionMode::Allow),
4756 always_allow: vec![],
4757 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4758 always_confirm: vec![],
4759 invalid_patterns: vec![],
4760 },
4761 );
4762 agent_settings::AgentSettings::override_global(settings, cx);
4763 });
4764
4765 #[allow(clippy::arc_with_non_send_sync)]
4766 let tool = Arc::new(crate::MovePathTool::new(project));
4767 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4768
4769 let task = cx.update(|cx| {
4770 tool.run(
4771 crate::MovePathToolInput {
4772 source_path: "root/safe.txt".to_string(),
4773 destination_path: "root/protected/safe.txt".to_string(),
4774 },
4775 event_stream,
4776 cx,
4777 )
4778 });
4779
4780 let result = task.await;
4781 assert!(
4782 result.is_err(),
4783 "expected move to be blocked due to destination path"
4784 );
4785 assert!(
4786 result.unwrap_err().to_string().contains("blocked"),
4787 "error should mention the move was blocked"
4788 );
4789}
4790
4791#[gpui::test]
4792async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
4793 init_test(cx);
4794
4795 let fs = FakeFs::new(cx.executor());
4796 fs.insert_tree(
4797 "/root",
4798 json!({
4799 "secret.txt": "secret content",
4800 "public": {}
4801 }),
4802 )
4803 .await;
4804 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4805
4806 cx.update(|cx| {
4807 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4808 settings.tool_permissions.tools.insert(
4809 MovePathTool::NAME.into(),
4810 agent_settings::ToolRules {
4811 default: Some(settings::ToolPermissionMode::Allow),
4812 always_allow: vec![],
4813 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
4814 always_confirm: vec![],
4815 invalid_patterns: vec![],
4816 },
4817 );
4818 agent_settings::AgentSettings::override_global(settings, cx);
4819 });
4820
4821 #[allow(clippy::arc_with_non_send_sync)]
4822 let tool = Arc::new(crate::MovePathTool::new(project));
4823 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4824
4825 let task = cx.update(|cx| {
4826 tool.run(
4827 crate::MovePathToolInput {
4828 source_path: "root/secret.txt".to_string(),
4829 destination_path: "root/public/not_secret.txt".to_string(),
4830 },
4831 event_stream,
4832 cx,
4833 )
4834 });
4835
4836 let result = task.await;
4837 assert!(
4838 result.is_err(),
4839 "expected move to be blocked due to source path"
4840 );
4841 assert!(
4842 result.unwrap_err().to_string().contains("blocked"),
4843 "error should mention the move was blocked"
4844 );
4845}
4846
4847#[gpui::test]
4848async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
4849 init_test(cx);
4850
4851 let fs = FakeFs::new(cx.executor());
4852 fs.insert_tree(
4853 "/root",
4854 json!({
4855 "confidential.txt": "confidential data",
4856 "dest": {}
4857 }),
4858 )
4859 .await;
4860 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4861
4862 cx.update(|cx| {
4863 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4864 settings.tool_permissions.tools.insert(
4865 CopyPathTool::NAME.into(),
4866 agent_settings::ToolRules {
4867 default: Some(settings::ToolPermissionMode::Allow),
4868 always_allow: vec![],
4869 always_deny: vec![
4870 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
4871 ],
4872 always_confirm: vec![],
4873 invalid_patterns: vec![],
4874 },
4875 );
4876 agent_settings::AgentSettings::override_global(settings, cx);
4877 });
4878
4879 #[allow(clippy::arc_with_non_send_sync)]
4880 let tool = Arc::new(crate::CopyPathTool::new(project));
4881 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4882
4883 let task = cx.update(|cx| {
4884 tool.run(
4885 crate::CopyPathToolInput {
4886 source_path: "root/confidential.txt".to_string(),
4887 destination_path: "root/dest/copy.txt".to_string(),
4888 },
4889 event_stream,
4890 cx,
4891 )
4892 });
4893
4894 let result = task.await;
4895 assert!(result.is_err(), "expected copy to be blocked");
4896 assert!(
4897 result.unwrap_err().to_string().contains("blocked"),
4898 "error should mention the copy was blocked"
4899 );
4900}
4901
4902#[gpui::test]
4903async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
4904 init_test(cx);
4905
4906 let fs = FakeFs::new(cx.executor());
4907 fs.insert_tree(
4908 "/root",
4909 json!({
4910 "normal.txt": "normal content",
4911 "readonly": {
4912 "config.txt": "readonly content"
4913 }
4914 }),
4915 )
4916 .await;
4917 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4918
4919 cx.update(|cx| {
4920 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4921 settings.tool_permissions.tools.insert(
4922 SaveFileTool::NAME.into(),
4923 agent_settings::ToolRules {
4924 default: Some(settings::ToolPermissionMode::Allow),
4925 always_allow: vec![],
4926 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
4927 always_confirm: vec![],
4928 invalid_patterns: vec![],
4929 },
4930 );
4931 agent_settings::AgentSettings::override_global(settings, cx);
4932 });
4933
4934 #[allow(clippy::arc_with_non_send_sync)]
4935 let tool = Arc::new(crate::SaveFileTool::new(project));
4936 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4937
4938 let task = cx.update(|cx| {
4939 tool.run(
4940 crate::SaveFileToolInput {
4941 paths: vec![
4942 std::path::PathBuf::from("root/normal.txt"),
4943 std::path::PathBuf::from("root/readonly/config.txt"),
4944 ],
4945 },
4946 event_stream,
4947 cx,
4948 )
4949 });
4950
4951 let result = task.await;
4952 assert!(
4953 result.is_err(),
4954 "expected save to be blocked due to denied path"
4955 );
4956 assert!(
4957 result.unwrap_err().to_string().contains("blocked"),
4958 "error should mention the save was blocked"
4959 );
4960}
4961
4962#[gpui::test]
4963async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
4964 init_test(cx);
4965
4966 let fs = FakeFs::new(cx.executor());
4967 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
4968 .await;
4969 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4970
4971 cx.update(|cx| {
4972 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4973 settings.tool_permissions.tools.insert(
4974 SaveFileTool::NAME.into(),
4975 agent_settings::ToolRules {
4976 default: Some(settings::ToolPermissionMode::Allow),
4977 always_allow: vec![],
4978 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
4979 always_confirm: vec![],
4980 invalid_patterns: vec![],
4981 },
4982 );
4983 agent_settings::AgentSettings::override_global(settings, cx);
4984 });
4985
4986 #[allow(clippy::arc_with_non_send_sync)]
4987 let tool = Arc::new(crate::SaveFileTool::new(project));
4988 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4989
4990 let task = cx.update(|cx| {
4991 tool.run(
4992 crate::SaveFileToolInput {
4993 paths: vec![std::path::PathBuf::from("root/config.secret")],
4994 },
4995 event_stream,
4996 cx,
4997 )
4998 });
4999
5000 let result = task.await;
5001 assert!(result.is_err(), "expected save to be blocked");
5002 assert!(
5003 result.unwrap_err().to_string().contains("blocked"),
5004 "error should mention the save was blocked"
5005 );
5006}
5007
5008#[gpui::test]
5009async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5010 init_test(cx);
5011
5012 cx.update(|cx| {
5013 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5014 settings.tool_permissions.tools.insert(
5015 WebSearchTool::NAME.into(),
5016 agent_settings::ToolRules {
5017 default: Some(settings::ToolPermissionMode::Allow),
5018 always_allow: vec![],
5019 always_deny: vec![
5020 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5021 ],
5022 always_confirm: vec![],
5023 invalid_patterns: vec![],
5024 },
5025 );
5026 agent_settings::AgentSettings::override_global(settings, cx);
5027 });
5028
5029 #[allow(clippy::arc_with_non_send_sync)]
5030 let tool = Arc::new(crate::WebSearchTool);
5031 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5032
5033 let input: crate::WebSearchToolInput =
5034 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5035
5036 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5037
5038 let result = task.await;
5039 assert!(result.is_err(), "expected search to be blocked");
5040 assert!(
5041 result.unwrap_err().to_string().contains("blocked"),
5042 "error should mention the search was blocked"
5043 );
5044}
5045
5046#[gpui::test]
5047async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5048 init_test(cx);
5049
5050 let fs = FakeFs::new(cx.executor());
5051 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5052 .await;
5053 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5054
5055 cx.update(|cx| {
5056 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5057 settings.tool_permissions.tools.insert(
5058 EditFileTool::NAME.into(),
5059 agent_settings::ToolRules {
5060 default: Some(settings::ToolPermissionMode::Confirm),
5061 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5062 always_deny: vec![],
5063 always_confirm: vec![],
5064 invalid_patterns: vec![],
5065 },
5066 );
5067 agent_settings::AgentSettings::override_global(settings, cx);
5068 });
5069
5070 let context_server_registry =
5071 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5072 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5073 let templates = crate::Templates::new();
5074 let thread = cx.new(|cx| {
5075 crate::Thread::new(
5076 project.clone(),
5077 cx.new(|_cx| prompt_store::ProjectContext::default()),
5078 context_server_registry,
5079 templates.clone(),
5080 None,
5081 cx,
5082 )
5083 });
5084
5085 #[allow(clippy::arc_with_non_send_sync)]
5086 let tool = Arc::new(crate::EditFileTool::new(
5087 project,
5088 thread.downgrade(),
5089 language_registry,
5090 templates,
5091 ));
5092 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5093
5094 let _task = cx.update(|cx| {
5095 tool.run(
5096 crate::EditFileToolInput {
5097 display_description: "Edit README".to_string(),
5098 path: "root/README.md".into(),
5099 mode: crate::EditFileMode::Edit,
5100 },
5101 event_stream,
5102 cx,
5103 )
5104 });
5105
5106 cx.run_until_parked();
5107
5108 let event = rx.try_next();
5109 assert!(
5110 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5111 "expected no authorization request for allowed .md file"
5112 );
5113}
5114
5115#[gpui::test]
5116async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5117 init_test(cx);
5118
5119 let fs = FakeFs::new(cx.executor());
5120 fs.insert_tree(
5121 "/root",
5122 json!({
5123 ".zed": {
5124 "settings.json": "{}"
5125 },
5126 "README.md": "# Hello"
5127 }),
5128 )
5129 .await;
5130 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5131
5132 cx.update(|cx| {
5133 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5134 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5135 agent_settings::AgentSettings::override_global(settings, cx);
5136 });
5137
5138 let context_server_registry =
5139 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5140 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5141 let templates = crate::Templates::new();
5142 let thread = cx.new(|cx| {
5143 crate::Thread::new(
5144 project.clone(),
5145 cx.new(|_cx| prompt_store::ProjectContext::default()),
5146 context_server_registry,
5147 templates.clone(),
5148 None,
5149 cx,
5150 )
5151 });
5152
5153 #[allow(clippy::arc_with_non_send_sync)]
5154 let tool = Arc::new(crate::EditFileTool::new(
5155 project,
5156 thread.downgrade(),
5157 language_registry,
5158 templates,
5159 ));
5160
5161 // Editing a file inside .zed/ should still prompt even with global default: allow,
5162 // because local settings paths are sensitive and require confirmation regardless.
5163 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5164 let _task = cx.update(|cx| {
5165 tool.run(
5166 crate::EditFileToolInput {
5167 display_description: "Edit local settings".to_string(),
5168 path: "root/.zed/settings.json".into(),
5169 mode: crate::EditFileMode::Edit,
5170 },
5171 event_stream,
5172 cx,
5173 )
5174 });
5175
5176 let _update = rx.expect_update_fields().await;
5177 let _auth = rx.expect_authorization().await;
5178}
5179
5180#[gpui::test]
5181async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5182 init_test(cx);
5183
5184 cx.update(|cx| {
5185 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5186 settings.tool_permissions.tools.insert(
5187 FetchTool::NAME.into(),
5188 agent_settings::ToolRules {
5189 default: Some(settings::ToolPermissionMode::Allow),
5190 always_allow: vec![],
5191 always_deny: vec![
5192 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5193 ],
5194 always_confirm: vec![],
5195 invalid_patterns: vec![],
5196 },
5197 );
5198 agent_settings::AgentSettings::override_global(settings, cx);
5199 });
5200
5201 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5202
5203 #[allow(clippy::arc_with_non_send_sync)]
5204 let tool = Arc::new(crate::FetchTool::new(http_client));
5205 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5206
5207 let input: crate::FetchToolInput =
5208 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5209
5210 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5211
5212 let result = task.await;
5213 assert!(result.is_err(), "expected fetch to be blocked");
5214 assert!(
5215 result.unwrap_err().to_string().contains("blocked"),
5216 "error should mention the fetch was blocked"
5217 );
5218}
5219
5220#[gpui::test]
5221async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5222 init_test(cx);
5223
5224 cx.update(|cx| {
5225 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5226 settings.tool_permissions.tools.insert(
5227 FetchTool::NAME.into(),
5228 agent_settings::ToolRules {
5229 default: Some(settings::ToolPermissionMode::Confirm),
5230 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5231 always_deny: vec![],
5232 always_confirm: vec![],
5233 invalid_patterns: vec![],
5234 },
5235 );
5236 agent_settings::AgentSettings::override_global(settings, cx);
5237 });
5238
5239 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5240
5241 #[allow(clippy::arc_with_non_send_sync)]
5242 let tool = Arc::new(crate::FetchTool::new(http_client));
5243 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5244
5245 let input: crate::FetchToolInput =
5246 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5247
5248 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5249
5250 cx.run_until_parked();
5251
5252 let event = rx.try_next();
5253 assert!(
5254 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5255 "expected no authorization request for allowed docs.rs URL"
5256 );
5257}
5258
5259#[gpui::test]
5260async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5261 init_test(cx);
5262 always_allow_tools(cx);
5263
5264 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5265 let fake_model = model.as_fake();
5266
5267 // Add a tool so we can simulate tool calls
5268 thread.update(cx, |thread, _cx| {
5269 thread.add_tool(EchoTool, None);
5270 });
5271
5272 // Start a turn by sending a message
5273 let mut events = thread
5274 .update(cx, |thread, cx| {
5275 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5276 })
5277 .unwrap();
5278 cx.run_until_parked();
5279
5280 // Simulate the model making a tool call
5281 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5282 LanguageModelToolUse {
5283 id: "tool_1".into(),
5284 name: "echo".into(),
5285 raw_input: r#"{"text": "hello"}"#.into(),
5286 input: json!({"text": "hello"}),
5287 is_input_complete: true,
5288 thought_signature: None,
5289 },
5290 ));
5291 fake_model
5292 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5293
5294 // Signal that a message is queued before ending the stream
5295 thread.update(cx, |thread, _cx| {
5296 thread.set_has_queued_message(true);
5297 });
5298
5299 // Now end the stream - tool will run, and the boundary check should see the queue
5300 fake_model.end_last_completion_stream();
5301
5302 // Collect all events until the turn stops
5303 let all_events = collect_events_until_stop(&mut events, cx).await;
5304
5305 // Verify we received the tool call event
5306 let tool_call_ids: Vec<_> = all_events
5307 .iter()
5308 .filter_map(|e| match e {
5309 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5310 _ => None,
5311 })
5312 .collect();
5313 assert_eq!(
5314 tool_call_ids,
5315 vec!["tool_1"],
5316 "Should have received a tool call event for our echo tool"
5317 );
5318
5319 // The turn should have stopped with EndTurn
5320 let stop_reasons = stop_events(all_events);
5321 assert_eq!(
5322 stop_reasons,
5323 vec![acp::StopReason::EndTurn],
5324 "Turn should have ended after tool completion due to queued message"
5325 );
5326
5327 // Verify the queued message flag is still set
5328 thread.update(cx, |thread, _cx| {
5329 assert!(
5330 thread.has_queued_message(),
5331 "Should still have queued message flag set"
5332 );
5333 });
5334
5335 // Thread should be idle now
5336 thread.update(cx, |thread, _cx| {
5337 assert!(
5338 thread.is_turn_complete(),
5339 "Thread should not be running after turn ends"
5340 );
5341 });
5342}