1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
66struct FakeTerminalHandle {
67 killed: Arc<AtomicBool>,
68 stopped_by_user: Arc<AtomicBool>,
69 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
70 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
71 output: acp::TerminalOutputResponse,
72 id: acp::TerminalId,
73}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
158struct FakeSubagentHandle {
159 session_id: acp::SessionId,
160 wait_for_summary_task: Shared<Task<String>>,
161}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
183#[derive(Default)]
184struct FakeThreadEnvironment {
185 terminal_handle: Option<Rc<FakeTerminalHandle>>,
186 subagent_handle: Option<Rc<FakeSubagentHandle>>,
187}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
237/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
238struct MultiTerminalEnvironment {
239 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
240}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
/// Sends one user message, streams "Hello" back from the fake model, and checks
/// that the last message is an assistant reply rendering as "Hello\n" and that
/// the turn stopped with `EndTurn`.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    // Let the thread issue its request so the fake model has a "last completion" to answer.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
        assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
311
/// Runs the terminal tool with a 5 ms timeout against a fake terminal that never
/// exits on its own. Checks that the tool emits terminal content, kills the
/// handle, and includes the terminal's "partial output" in its result.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so `was_killed` can be checked after the tool finishes.
    let handle = environment.terminal_handle.clone().unwrap();

    // The environment is an `Rc`, so the tool is presumably not Send/Sync; the
    // lint about wrapping it in `Arc` is silenced deliberately.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task without blocking. Between polls, drain the executor and wait
    // on a 1 ms executor timer. Give up after a 500 ms wall-clock deadline.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
377
/// Runs the terminal tool with no timeout against a never-exiting fake terminal
/// and checks that the handle has not been killed after waiting 25 ms.
///
/// NOTE(review): this test is `#[ignore]`d without a stated reason. Consider
/// `#[ignore = "..."]` so the reason is recorded.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Bound to a named variable (not `_`) so the task stays alive for the rest of
    // the test; dropping a gpui `Task` presumably cancels it.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
427
/// Streams a `Thinking` event followed by text and checks that the assistant
/// message renders the thinking step inside `<think>` tags before the answer.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
470
/// With a tool registered and a custom shell in the project context, checks
/// that exactly one completion is requested. Its first message must be a system
/// prompt that mentions the shell and contains the "## Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
515
/// With no tools registered, checks that the system prompt omits both the
/// "## Tool Use" and the "## Fixing Diagnostics" sections.
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
551
/// Checks that only the most recent message of each request has `cache: true`.
/// That message is the latest user message, or the latest tool result after a
/// tool call. Every earlier message has `cache: false`.
/// `messages[1..]` skips the leading system message.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    // Running the tool makes the thread issue a follow-up request carrying its result.
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
697
/// End-to-end check of tool calls against `TestModel::Sonnet4` rather than the
/// fake model. It presumably needs real provider access, so it is ignored
/// unless the `e2e` feature is enabled.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool, None);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The final agent message must contain "Ding" in some text chunk; on failure,
    // dump the whole thread as markdown.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
757
/// End-to-end check, gated on the `e2e` feature, that tool-call input streams in
/// incrementally. While a `word_list` call is `Pending`, at least one event must
/// show incomplete input that lacks key "g". Once the call leaves `Pending`,
/// both "a" and "g" must be present.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to call `word_list` and watch the tool call stream in.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool, None);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
809
/// Exercises the permission flow for `ToolRequiringPermission`:
/// 1. Two calls: the first is allowed once ("allow") and succeeds; the second
///    is denied ("deny") and yields an error result.
/// 2. A third call is answered with "always allow" for this tool.
/// 3. A fourth call must then run without an authorization prompt.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission, None);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request's last message carries both tool results, in call order.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // No authorization is answered here: the result must appear on its own.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
948
/// When the model calls a tool that is not registered ("nonexistent_tool"),
/// checks that a `Pending` tool call titled with that name is emitted and is
/// then updated to `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
/// For `cargo build --release`, the terminal tool offers exactly three dropdown
/// choices, one of which is scoped to `cargo` commands.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo` commands"));
    assert!(labels.contains(&"Only this time"));
}
1059
/// For `src/main.rs`, the edit-file tool offers a tool-wide choice and one
/// scoped to the `src/` directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
}
1077
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a URL offers a tool-wide choice plus one scoped to the URL's domain.
    let context =
        ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let allow_labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(allow_labels.contains(&expected));
    }
}
1095
#[test]
fn test_permission_options_without_pattern() {
    // A script invocation yields no command pattern, so only the tool-wide and
    // one-time choices should be offered.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);
    let allow_labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(allow_labels.contains(&expected));
    }
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1117
#[test]
fn test_permission_option_ids_for_terminal() {
    // Each dropdown choice pairs an allow option with a deny option; check the ids on both sides.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    for expected in ["always_allow:terminal", "allow"] {
        assert!(allow_ids.iter().any(|id| id == expected));
    }
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    for expected in ["always_deny:terminal", "deny"] {
        assert!(deny_ids.iter().any(|id| id == expected));
    }
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1157
1158#[gpui::test]
1159#[cfg_attr(not(feature = "e2e"), ignore)]
1160async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1161 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1162
1163 // Test concurrent tool calls with different delay times
1164 let events = thread
1165 .update(cx, |thread, cx| {
1166 thread.add_tool(DelayTool, None);
1167 thread.send(
1168 UserMessageId::new(),
1169 [
1170 "Call the delay tool twice in the same message.",
1171 "Once with 100ms. Once with 300ms.",
1172 "When both timers are complete, describe the outputs.",
1173 ],
1174 cx,
1175 )
1176 })
1177 .unwrap()
1178 .collect()
1179 .await;
1180
1181 let stop_reasons = stop_events(events);
1182 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1183
1184 thread.update(cx, |thread, _cx| {
1185 let last_message = thread.last_message().unwrap();
1186 let agent_message = last_message.as_agent_message().unwrap();
1187 let text = agent_message
1188 .content
1189 .iter()
1190 .filter_map(|content| {
1191 if let AgentMessageContent::Text(text) = content {
1192 Some(text.as_str())
1193 } else {
1194 None
1195 }
1196 })
1197 .collect::<String>();
1198
1199 assert!(text.contains("Ding"));
1200 });
1201}
1202
1203#[gpui::test]
1204async fn test_profiles(cx: &mut TestAppContext) {
1205 let ThreadTest {
1206 model, thread, fs, ..
1207 } = setup(cx, TestModel::Fake).await;
1208 let fake_model = model.as_fake();
1209
1210 thread.update(cx, |thread, _cx| {
1211 thread.add_tool(DelayTool, None);
1212 thread.add_tool(EchoTool, None);
1213 thread.add_tool(InfiniteTool, None);
1214 });
1215
1216 // Override profiles and wait for settings to be loaded.
1217 fs.insert_file(
1218 paths::settings_file(),
1219 json!({
1220 "agent": {
1221 "profiles": {
1222 "test-1": {
1223 "name": "Test Profile 1",
1224 "tools": {
1225 EchoTool::NAME: true,
1226 DelayTool::NAME: true,
1227 }
1228 },
1229 "test-2": {
1230 "name": "Test Profile 2",
1231 "tools": {
1232 InfiniteTool::NAME: true,
1233 }
1234 }
1235 }
1236 }
1237 })
1238 .to_string()
1239 .into_bytes(),
1240 )
1241 .await;
1242 cx.run_until_parked();
1243
1244 // Test that test-1 profile (default) has echo and delay tools
1245 thread
1246 .update(cx, |thread, cx| {
1247 thread.set_profile(AgentProfileId("test-1".into()), cx);
1248 thread.send(UserMessageId::new(), ["test"], cx)
1249 })
1250 .unwrap();
1251 cx.run_until_parked();
1252
1253 let mut pending_completions = fake_model.pending_completions();
1254 assert_eq!(pending_completions.len(), 1);
1255 let completion = pending_completions.pop().unwrap();
1256 let tool_names: Vec<String> = completion
1257 .tools
1258 .iter()
1259 .map(|tool| tool.name.clone())
1260 .collect();
1261 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1262 fake_model.end_last_completion_stream();
1263
1264 // Switch to test-2 profile, and verify that it has only the infinite tool.
1265 thread
1266 .update(cx, |thread, cx| {
1267 thread.set_profile(AgentProfileId("test-2".into()), cx);
1268 thread.send(UserMessageId::new(), ["test2"], cx)
1269 })
1270 .unwrap();
1271 cx.run_until_parked();
1272 let mut pending_completions = fake_model.pending_completions();
1273 assert_eq!(pending_completions.len(), 1);
1274 let completion = pending_completions.pop().unwrap();
1275 let tool_names: Vec<String> = completion
1276 .tools
1277 .iter()
1278 .map(|tool| tool.name.clone())
1279 .collect();
1280 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1281}
1282
1283#[gpui::test]
1284async fn test_mcp_tools(cx: &mut TestAppContext) {
1285 let ThreadTest {
1286 model,
1287 thread,
1288 context_server_store,
1289 fs,
1290 ..
1291 } = setup(cx, TestModel::Fake).await;
1292 let fake_model = model.as_fake();
1293
1294 // Override profiles and wait for settings to be loaded.
1295 fs.insert_file(
1296 paths::settings_file(),
1297 json!({
1298 "agent": {
1299 "tool_permissions": { "default": "allow" },
1300 "profiles": {
1301 "test": {
1302 "name": "Test Profile",
1303 "enable_all_context_servers": true,
1304 "tools": {
1305 EchoTool::NAME: true,
1306 }
1307 },
1308 }
1309 }
1310 })
1311 .to_string()
1312 .into_bytes(),
1313 )
1314 .await;
1315 cx.run_until_parked();
1316 thread.update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test".into()), cx)
1318 });
1319
1320 let mut mcp_tool_calls = setup_context_server(
1321 "test_server",
1322 vec![context_server::types::Tool {
1323 name: "echo".into(),
1324 description: None,
1325 input_schema: serde_json::to_value(EchoTool::input_schema(
1326 LanguageModelToolSchemaFormat::JsonSchema,
1327 ))
1328 .unwrap(),
1329 output_schema: None,
1330 annotations: None,
1331 }],
1332 &context_server_store,
1333 cx,
1334 );
1335
1336 let events = thread.update(cx, |thread, cx| {
1337 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1338 });
1339 cx.run_until_parked();
1340
1341 // Simulate the model calling the MCP tool.
1342 let completion = fake_model.pending_completions().pop().unwrap();
1343 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1344 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1345 LanguageModelToolUse {
1346 id: "tool_1".into(),
1347 name: "echo".into(),
1348 raw_input: json!({"text": "test"}).to_string(),
1349 input: json!({"text": "test"}),
1350 is_input_complete: true,
1351 thought_signature: None,
1352 },
1353 ));
1354 fake_model.end_last_completion_stream();
1355 cx.run_until_parked();
1356
1357 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1358 assert_eq!(tool_call_params.name, "echo");
1359 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1360 tool_call_response
1361 .send(context_server::types::CallToolResponse {
1362 content: vec![context_server::types::ToolResponseContent::Text {
1363 text: "test".into(),
1364 }],
1365 is_error: None,
1366 meta: None,
1367 structured_content: None,
1368 })
1369 .unwrap();
1370 cx.run_until_parked();
1371
1372 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1373 fake_model.send_last_completion_stream_text_chunk("Done!");
1374 fake_model.end_last_completion_stream();
1375 events.collect::<Vec<_>>().await;
1376
1377 // Send again after adding the echo tool, ensuring the name collision is resolved.
1378 let events = thread.update(cx, |thread, cx| {
1379 thread.add_tool(EchoTool, None);
1380 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1381 });
1382 cx.run_until_parked();
1383 let completion = fake_model.pending_completions().pop().unwrap();
1384 assert_eq!(
1385 tool_names_for_completion(&completion),
1386 vec!["echo", "test_server_echo"]
1387 );
1388 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1389 LanguageModelToolUse {
1390 id: "tool_2".into(),
1391 name: "test_server_echo".into(),
1392 raw_input: json!({"text": "mcp"}).to_string(),
1393 input: json!({"text": "mcp"}),
1394 is_input_complete: true,
1395 thought_signature: None,
1396 },
1397 ));
1398 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1399 LanguageModelToolUse {
1400 id: "tool_3".into(),
1401 name: "echo".into(),
1402 raw_input: json!({"text": "native"}).to_string(),
1403 input: json!({"text": "native"}),
1404 is_input_complete: true,
1405 thought_signature: None,
1406 },
1407 ));
1408 fake_model.end_last_completion_stream();
1409 cx.run_until_parked();
1410
1411 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1412 assert_eq!(tool_call_params.name, "echo");
1413 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1414 tool_call_response
1415 .send(context_server::types::CallToolResponse {
1416 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1417 is_error: None,
1418 meta: None,
1419 structured_content: None,
1420 })
1421 .unwrap();
1422 cx.run_until_parked();
1423
1424 // Ensure the tool results were inserted with the correct names.
1425 let completion = fake_model.pending_completions().pop().unwrap();
1426 assert_eq!(
1427 completion.messages.last().unwrap().content,
1428 vec![
1429 MessageContent::ToolResult(LanguageModelToolResult {
1430 tool_use_id: "tool_3".into(),
1431 tool_name: "echo".into(),
1432 is_error: false,
1433 content: "native".into(),
1434 output: Some("native".into()),
1435 },),
1436 MessageContent::ToolResult(LanguageModelToolResult {
1437 tool_use_id: "tool_2".into(),
1438 tool_name: "test_server_echo".into(),
1439 is_error: false,
1440 content: "mcp".into(),
1441 output: Some("mcp".into()),
1442 },),
1443 ]
1444 );
1445 fake_model.end_last_completion_stream();
1446 events.collect::<Vec<_>>().await;
1447}
1448
/// Regression test: a completed MCP tool result must survive replay after its server is gone.
///
/// Flow: an MCP tool (`issue_read` on `github_server`) is called and answered, and the
/// thread finishes its turn. The context server is then stopped, which simulates an app
/// restart where the server has not reconnected yet, and the thread is replayed. Replay
/// must still emit the tool call plus a `ToolCallUpdate` carrying the saved `raw_output`
/// with status `Completed`.
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    // NOTE(review): this uses `always_allow_tool_actions`, while other tests in this file
    // use `tool_permissions.default = "allow"` — confirm which setting is current.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown output.
    // The tool result is in the first assistant message (not the last one, which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    // The return value of `stop_server` is deliberately discarded here.
    context_server_store.update(cx, |store, cx| {
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    // Record the replayed tool call for `tool_1`, and the last update for it that
    // carries a `raw_output`.
    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1630
1631#[gpui::test]
1632async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1633 let ThreadTest {
1634 model,
1635 thread,
1636 context_server_store,
1637 fs,
1638 ..
1639 } = setup(cx, TestModel::Fake).await;
1640 let fake_model = model.as_fake();
1641
1642 // Set up a profile with all tools enabled
1643 fs.insert_file(
1644 paths::settings_file(),
1645 json!({
1646 "agent": {
1647 "profiles": {
1648 "test": {
1649 "name": "Test Profile",
1650 "enable_all_context_servers": true,
1651 "tools": {
1652 EchoTool::NAME: true,
1653 DelayTool::NAME: true,
1654 WordListTool::NAME: true,
1655 ToolRequiringPermission::NAME: true,
1656 InfiniteTool::NAME: true,
1657 }
1658 },
1659 }
1660 }
1661 })
1662 .to_string()
1663 .into_bytes(),
1664 )
1665 .await;
1666 cx.run_until_parked();
1667
1668 thread.update(cx, |thread, cx| {
1669 thread.set_profile(AgentProfileId("test".into()), cx);
1670 thread.add_tool(EchoTool, None);
1671 thread.add_tool(DelayTool, None);
1672 thread.add_tool(WordListTool, None);
1673 thread.add_tool(ToolRequiringPermission, None);
1674 thread.add_tool(InfiniteTool, None);
1675 });
1676
1677 // Set up multiple context servers with some overlapping tool names
1678 let _server1_calls = setup_context_server(
1679 "xxx",
1680 vec![
1681 context_server::types::Tool {
1682 name: "echo".into(), // Conflicts with native EchoTool
1683 description: None,
1684 input_schema: serde_json::to_value(EchoTool::input_schema(
1685 LanguageModelToolSchemaFormat::JsonSchema,
1686 ))
1687 .unwrap(),
1688 output_schema: None,
1689 annotations: None,
1690 },
1691 context_server::types::Tool {
1692 name: "unique_tool_1".into(),
1693 description: None,
1694 input_schema: json!({"type": "object", "properties": {}}),
1695 output_schema: None,
1696 annotations: None,
1697 },
1698 ],
1699 &context_server_store,
1700 cx,
1701 );
1702
1703 let _server2_calls = setup_context_server(
1704 "yyy",
1705 vec![
1706 context_server::types::Tool {
1707 name: "echo".into(), // Also conflicts with native EchoTool
1708 description: None,
1709 input_schema: serde_json::to_value(EchoTool::input_schema(
1710 LanguageModelToolSchemaFormat::JsonSchema,
1711 ))
1712 .unwrap(),
1713 output_schema: None,
1714 annotations: None,
1715 },
1716 context_server::types::Tool {
1717 name: "unique_tool_2".into(),
1718 description: None,
1719 input_schema: json!({"type": "object", "properties": {}}),
1720 output_schema: None,
1721 annotations: None,
1722 },
1723 context_server::types::Tool {
1724 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1725 description: None,
1726 input_schema: json!({"type": "object", "properties": {}}),
1727 output_schema: None,
1728 annotations: None,
1729 },
1730 context_server::types::Tool {
1731 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1732 description: None,
1733 input_schema: json!({"type": "object", "properties": {}}),
1734 output_schema: None,
1735 annotations: None,
1736 },
1737 ],
1738 &context_server_store,
1739 cx,
1740 );
1741 let _server3_calls = setup_context_server(
1742 "zzz",
1743 vec![
1744 context_server::types::Tool {
1745 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1746 description: None,
1747 input_schema: json!({"type": "object", "properties": {}}),
1748 output_schema: None,
1749 annotations: None,
1750 },
1751 context_server::types::Tool {
1752 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1753 description: None,
1754 input_schema: json!({"type": "object", "properties": {}}),
1755 output_schema: None,
1756 annotations: None,
1757 },
1758 context_server::types::Tool {
1759 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1760 description: None,
1761 input_schema: json!({"type": "object", "properties": {}}),
1762 output_schema: None,
1763 annotations: None,
1764 },
1765 ],
1766 &context_server_store,
1767 cx,
1768 );
1769
1770 // Server with spaces in name - tests snake_case conversion for API compatibility
1771 let _server4_calls = setup_context_server(
1772 "Azure DevOps",
1773 vec![context_server::types::Tool {
1774 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1775 description: None,
1776 input_schema: serde_json::to_value(EchoTool::input_schema(
1777 LanguageModelToolSchemaFormat::JsonSchema,
1778 ))
1779 .unwrap(),
1780 output_schema: None,
1781 annotations: None,
1782 }],
1783 &context_server_store,
1784 cx,
1785 );
1786
1787 thread
1788 .update(cx, |thread, cx| {
1789 thread.send(UserMessageId::new(), ["Go"], cx)
1790 })
1791 .unwrap();
1792 cx.run_until_parked();
1793 let completion = fake_model.pending_completions().pop().unwrap();
1794 assert_eq!(
1795 tool_names_for_completion(&completion),
1796 vec![
1797 "azure_dev_ops_echo",
1798 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1799 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1800 "delay",
1801 "echo",
1802 "infinite",
1803 "tool_requiring_permission",
1804 "unique_tool_1",
1805 "unique_tool_2",
1806 "word_list",
1807 "xxx_echo",
1808 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1809 "yyy_echo",
1810 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1811 ]
1812 );
1813}
1814
1815#[gpui::test]
1816#[cfg_attr(not(feature = "e2e"), ignore)]
1817async fn test_cancellation(cx: &mut TestAppContext) {
1818 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1819
1820 let mut events = thread
1821 .update(cx, |thread, cx| {
1822 thread.add_tool(InfiniteTool, None);
1823 thread.add_tool(EchoTool, None);
1824 thread.send(
1825 UserMessageId::new(),
1826 ["Call the echo tool, then call the infinite tool, then explain their output"],
1827 cx,
1828 )
1829 })
1830 .unwrap();
1831
1832 // Wait until both tools are called.
1833 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1834 let mut echo_id = None;
1835 let mut echo_completed = false;
1836 while let Some(event) = events.next().await {
1837 match event.unwrap() {
1838 ThreadEvent::ToolCall(tool_call) => {
1839 assert_eq!(tool_call.title, expected_tools.remove(0));
1840 if tool_call.title == "Echo" {
1841 echo_id = Some(tool_call.tool_call_id);
1842 }
1843 }
1844 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1845 acp::ToolCallUpdate {
1846 tool_call_id,
1847 fields:
1848 acp::ToolCallUpdateFields {
1849 status: Some(acp::ToolCallStatus::Completed),
1850 ..
1851 },
1852 ..
1853 },
1854 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1855 echo_completed = true;
1856 }
1857 _ => {}
1858 }
1859
1860 if expected_tools.is_empty() && echo_completed {
1861 break;
1862 }
1863 }
1864
1865 // Cancel the current send and ensure that the event stream is closed, even
1866 // if one of the tools is still running.
1867 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1868 let events = events.collect::<Vec<_>>().await;
1869 let last_event = events.last();
1870 assert!(
1871 matches!(
1872 last_event,
1873 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1874 ),
1875 "unexpected event {last_event:?}"
1876 );
1877
1878 // Ensure we can still send a new message after cancellation.
1879 let events = thread
1880 .update(cx, |thread, cx| {
1881 thread.send(
1882 UserMessageId::new(),
1883 ["Testing: reply with 'Hello' then stop."],
1884 cx,
1885 )
1886 })
1887 .unwrap()
1888 .collect::<Vec<_>>()
1889 .await;
1890 thread.update(cx, |thread, _cx| {
1891 let message = thread.last_message().unwrap();
1892 let agent_message = message.as_agent_message().unwrap();
1893 assert_eq!(
1894 agent_message.content,
1895 vec![AgentMessageContent::Text("Hello".to_string())]
1896 );
1897 });
1898 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1899}
1900
1901#[gpui::test]
1902async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1903 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1904 always_allow_tools(cx);
1905 let fake_model = model.as_fake();
1906
1907 let environment = Rc::new(cx.update(|cx| {
1908 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1909 }));
1910 let handle = environment.terminal_handle.clone().unwrap();
1911
1912 let mut events = thread
1913 .update(cx, |thread, cx| {
1914 thread.add_tool(
1915 crate::TerminalTool::new(thread.project().clone(), environment),
1916 None,
1917 );
1918 thread.send(UserMessageId::new(), ["run a command"], cx)
1919 })
1920 .unwrap();
1921
1922 cx.run_until_parked();
1923
1924 // Simulate the model calling the terminal tool
1925 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1926 LanguageModelToolUse {
1927 id: "terminal_tool_1".into(),
1928 name: TerminalTool::NAME.into(),
1929 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1930 input: json!({"command": "sleep 1000", "cd": "."}),
1931 is_input_complete: true,
1932 thought_signature: None,
1933 },
1934 ));
1935 fake_model.end_last_completion_stream();
1936
1937 // Wait for the terminal tool to start running
1938 wait_for_terminal_tool_started(&mut events, cx).await;
1939
1940 // Cancel the thread while the terminal is running
1941 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1942
1943 // Collect remaining events, driving the executor to let cancellation complete
1944 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1945
1946 // Verify the terminal was killed
1947 assert!(
1948 handle.was_killed(),
1949 "expected terminal handle to be killed on cancellation"
1950 );
1951
1952 // Verify we got a cancellation stop event
1953 assert_eq!(
1954 stop_events(remaining_events),
1955 vec![acp::StopReason::Cancelled],
1956 );
1957
1958 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1959 thread.update(cx, |thread, _cx| {
1960 let message = thread.last_message().unwrap();
1961 let agent_message = message.as_agent_message().unwrap();
1962
1963 let tool_use = agent_message
1964 .content
1965 .iter()
1966 .find_map(|content| match content {
1967 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1968 _ => None,
1969 })
1970 .expect("expected tool use in agent message");
1971
1972 let tool_result = agent_message
1973 .tool_results
1974 .get(&tool_use.id)
1975 .expect("expected tool result");
1976
1977 let result_text = match &tool_result.content {
1978 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1979 _ => panic!("expected text content in tool result"),
1980 };
1981
1982 // "partial output" comes from FakeTerminalHandle's output field
1983 assert!(
1984 result_text.contains("partial output"),
1985 "expected tool result to contain terminal output, got: {result_text}"
1986 );
1987 // Match the actual format from process_content in terminal_tool.rs
1988 assert!(
1989 result_text.contains("The user stopped this command"),
1990 "expected tool result to indicate user stopped, got: {result_text}"
1991 );
1992 });
1993
1994 // Verify we can send a new message after cancellation
1995 verify_thread_recovery(&thread, &fake_model, cx).await;
1996}
1997
/// A tool that watches `event_stream.cancelled_by_user()` must let `Thread::cancel`
/// finish well within a 5-second timeout, observe the cancellation, and leave the thread
/// reusable afterwards.
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is the flag the tool sets once it notices cancellation.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool, None);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported
    // Poll without blocking: drain whatever events are ready, then wait 10ms and retry,
    // for at most `num_cpus() * 100` rounds.
    let mut tool_started = false;
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    // Race the cancel task against a 5s timer; losing the race fails the test.
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2092
2093/// Helper to verify thread can recover after cancellation by sending a simple message.
2094async fn verify_thread_recovery(
2095 thread: &Entity<Thread>,
2096 fake_model: &FakeLanguageModel,
2097 cx: &mut TestAppContext,
2098) {
2099 let events = thread
2100 .update(cx, |thread, cx| {
2101 thread.send(
2102 UserMessageId::new(),
2103 ["Testing: reply with 'Hello' then stop."],
2104 cx,
2105 )
2106 })
2107 .unwrap();
2108 cx.run_until_parked();
2109 fake_model.send_last_completion_stream_text_chunk("Hello");
2110 fake_model
2111 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2112 fake_model.end_last_completion_stream();
2113
2114 let events = events.collect::<Vec<_>>().await;
2115 thread.update(cx, |thread, _cx| {
2116 let message = thread.last_message().unwrap();
2117 let agent_message = message.as_agent_message().unwrap();
2118 assert_eq!(
2119 agent_message.content,
2120 vec![AgentMessageContent::Text("Hello".to_string())]
2121 );
2122 });
2123 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2124}
2125
2126/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2127async fn wait_for_terminal_tool_started(
2128 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2129 cx: &mut TestAppContext,
2130) {
2131 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2132 for _ in 0..deadline {
2133 cx.run_until_parked();
2134
2135 while let Some(Some(event)) = events.next().now_or_never() {
2136 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2137 update,
2138 ))) = &event
2139 {
2140 if update.fields.content.as_ref().is_some_and(|content| {
2141 content
2142 .iter()
2143 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2144 }) {
2145 return;
2146 }
2147 }
2148 }
2149
2150 cx.background_executor
2151 .timer(Duration::from_millis(10))
2152 .await;
2153 }
2154 panic!("terminal tool did not start within the expected time");
2155}
2156
2157/// Collects events until a Stop event is received, driving the executor to completion.
2158async fn collect_events_until_stop(
2159 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2160 cx: &mut TestAppContext,
2161) -> Vec<Result<ThreadEvent>> {
2162 let mut collected = Vec::new();
2163 let deadline = cx.executor().num_cpus() * 200;
2164
2165 for _ in 0..deadline {
2166 cx.executor().advance_clock(Duration::from_millis(10));
2167 cx.run_until_parked();
2168
2169 while let Some(Some(event)) = events.next().now_or_never() {
2170 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2171 collected.push(event);
2172 if is_stop {
2173 return collected;
2174 }
2175 }
2176 }
2177 panic!(
2178 "did not receive Stop event within the expected time; collected {} events",
2179 collected.len()
2180 );
2181}
2182
2183#[gpui::test]
2184async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2185 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2186 always_allow_tools(cx);
2187 let fake_model = model.as_fake();
2188
2189 let environment = Rc::new(cx.update(|cx| {
2190 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2191 }));
2192 let handle = environment.terminal_handle.clone().unwrap();
2193
2194 let message_id = UserMessageId::new();
2195 let mut events = thread
2196 .update(cx, |thread, cx| {
2197 thread.add_tool(
2198 crate::TerminalTool::new(thread.project().clone(), environment),
2199 None,
2200 );
2201 thread.send(message_id.clone(), ["run a command"], cx)
2202 })
2203 .unwrap();
2204
2205 cx.run_until_parked();
2206
2207 // Simulate the model calling the terminal tool
2208 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2209 LanguageModelToolUse {
2210 id: "terminal_tool_1".into(),
2211 name: TerminalTool::NAME.into(),
2212 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2213 input: json!({"command": "sleep 1000", "cd": "."}),
2214 is_input_complete: true,
2215 thought_signature: None,
2216 },
2217 ));
2218 fake_model.end_last_completion_stream();
2219
2220 // Wait for the terminal tool to start running
2221 wait_for_terminal_tool_started(&mut events, cx).await;
2222
2223 // Truncate the thread while the terminal is running
2224 thread
2225 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2226 .unwrap();
2227
2228 // Drive the executor to let cancellation complete
2229 let _ = collect_events_until_stop(&mut events, cx).await;
2230
2231 // Verify the terminal was killed
2232 assert!(
2233 handle.was_killed(),
2234 "expected terminal handle to be killed on truncate"
2235 );
2236
2237 // Verify the thread is empty after truncation
2238 thread.update(cx, |thread, _cx| {
2239 assert_eq!(
2240 thread.to_markdown(),
2241 "",
2242 "expected thread to be empty after truncating the only message"
2243 );
2244 });
2245
2246 // Verify we can send a new message after truncation
2247 verify_thread_recovery(&thread, &fake_model, cx).await;
2248}
2249
2250#[gpui::test]
2251async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2252 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2253 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2254 always_allow_tools(cx);
2255 let fake_model = model.as_fake();
2256
2257 let environment = Rc::new(MultiTerminalEnvironment::new());
2258
2259 let mut events = thread
2260 .update(cx, |thread, cx| {
2261 thread.add_tool(
2262 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2263 None,
2264 );
2265 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2266 })
2267 .unwrap();
2268
2269 cx.run_until_parked();
2270
2271 // Simulate the model calling two terminal tools
2272 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2273 LanguageModelToolUse {
2274 id: "terminal_tool_1".into(),
2275 name: TerminalTool::NAME.into(),
2276 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2277 input: json!({"command": "sleep 1000", "cd": "."}),
2278 is_input_complete: true,
2279 thought_signature: None,
2280 },
2281 ));
2282 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2283 LanguageModelToolUse {
2284 id: "terminal_tool_2".into(),
2285 name: TerminalTool::NAME.into(),
2286 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2287 input: json!({"command": "sleep 2000", "cd": "."}),
2288 is_input_complete: true,
2289 thought_signature: None,
2290 },
2291 ));
2292 fake_model.end_last_completion_stream();
2293
2294 // Wait for both terminal tools to start by counting terminal content updates
2295 let mut terminals_started = 0;
2296 let deadline = cx.executor().num_cpus() * 100;
2297 for _ in 0..deadline {
2298 cx.run_until_parked();
2299
2300 while let Some(Some(event)) = events.next().now_or_never() {
2301 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2302 update,
2303 ))) = &event
2304 {
2305 if update.fields.content.as_ref().is_some_and(|content| {
2306 content
2307 .iter()
2308 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2309 }) {
2310 terminals_started += 1;
2311 if terminals_started >= 2 {
2312 break;
2313 }
2314 }
2315 }
2316 }
2317 if terminals_started >= 2 {
2318 break;
2319 }
2320
2321 cx.background_executor
2322 .timer(Duration::from_millis(10))
2323 .await;
2324 }
2325 assert!(
2326 terminals_started >= 2,
2327 "expected 2 terminal tools to start, got {terminals_started}"
2328 );
2329
2330 // Cancel the thread while both terminals are running
2331 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2332
2333 // Collect remaining events
2334 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2335
2336 // Verify both terminal handles were killed
2337 let handles = environment.handles();
2338 assert_eq!(
2339 handles.len(),
2340 2,
2341 "expected 2 terminal handles to be created"
2342 );
2343 assert!(
2344 handles[0].was_killed(),
2345 "expected first terminal handle to be killed on cancellation"
2346 );
2347 assert!(
2348 handles[1].was_killed(),
2349 "expected second terminal handle to be killed on cancellation"
2350 );
2351
2352 // Verify we got a cancellation stop event
2353 assert_eq!(
2354 stop_events(remaining_events),
2355 vec![acp::StopReason::Cancelled],
2356 );
2357}
2358
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    //
    // Expected outcome: the turn still ends with `EndTurn` (the thread itself is not
    // cancelled), and the tool result text contains "The user stopped this command".
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake terminal built with `new_never_exits`; the test triggers its exit manually below.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Handle to the same fake terminal, used to drive it from the test.
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    // (this acts on the follow-up completion request, now the "last" one on the fake model).
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // First tool use in the agent message; its id is the key into `tool_results`.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        // Only a text tool result is accepted; any other content variant fails the test.
        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2453
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    //
    // Expected outcome: the terminal is killed, the turn ends with `EndTurn`, and the
    // tool result mentions "timed out" but not "The user stopped".
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake terminal built with `new_never_exits`, so only the timeout can end the command.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Handle to the same fake terminal, used to check whether it was killed.
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout
    // (200ms of fake time versus the 100ms `timeout_ms` requested above).
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // First tool use in the agent message; its id is the key into `tool_results`.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        // Only a text tool result is accepted; any other content variant fails the test.
        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2552
2553#[gpui::test]
2554async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2555 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2556 let fake_model = model.as_fake();
2557
2558 let events_1 = thread
2559 .update(cx, |thread, cx| {
2560 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2561 })
2562 .unwrap();
2563 cx.run_until_parked();
2564 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2565 cx.run_until_parked();
2566
2567 let events_2 = thread
2568 .update(cx, |thread, cx| {
2569 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2570 })
2571 .unwrap();
2572 cx.run_until_parked();
2573 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2574 fake_model
2575 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2576 fake_model.end_last_completion_stream();
2577
2578 let events_1 = events_1.collect::<Vec<_>>().await;
2579 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2580 let events_2 = events_2.collect::<Vec<_>>().await;
2581 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2582}
2583
2584#[gpui::test]
2585async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2586 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2587 let fake_model = model.as_fake();
2588
2589 let events_1 = thread
2590 .update(cx, |thread, cx| {
2591 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2592 })
2593 .unwrap();
2594 cx.run_until_parked();
2595 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2596 fake_model
2597 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2598 fake_model.end_last_completion_stream();
2599 let events_1 = events_1.collect::<Vec<_>>().await;
2600
2601 let events_2 = thread
2602 .update(cx, |thread, cx| {
2603 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2604 })
2605 .unwrap();
2606 cx.run_until_parked();
2607 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2608 fake_model
2609 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2610 fake_model.end_last_completion_stream();
2611 let events_2 = events_2.collect::<Vec<_>>().await;
2612
2613 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2614 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2615}
2616
2617#[gpui::test]
2618async fn test_refusal(cx: &mut TestAppContext) {
2619 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2620 let fake_model = model.as_fake();
2621
2622 let events = thread
2623 .update(cx, |thread, cx| {
2624 thread.send(UserMessageId::new(), ["Hello"], cx)
2625 })
2626 .unwrap();
2627 cx.run_until_parked();
2628 thread.read_with(cx, |thread, _| {
2629 assert_eq!(
2630 thread.to_markdown(),
2631 indoc! {"
2632 ## User
2633
2634 Hello
2635 "}
2636 );
2637 });
2638
2639 fake_model.send_last_completion_stream_text_chunk("Hey!");
2640 cx.run_until_parked();
2641 thread.read_with(cx, |thread, _| {
2642 assert_eq!(
2643 thread.to_markdown(),
2644 indoc! {"
2645 ## User
2646
2647 Hello
2648
2649 ## Assistant
2650
2651 Hey!
2652 "}
2653 );
2654 });
2655
2656 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2657 fake_model
2658 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2659 let events = events.collect::<Vec<_>>().await;
2660 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2661 thread.read_with(cx, |thread, _| {
2662 assert_eq!(thread.to_markdown(), "");
2663 });
2664}
2665
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) user message empties the thread, resets the
    // reported token usage to `None`, and leaves the thread usable for a new send whose
    // usage is then reported on its own.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Kept so the thread can later be truncated at this message.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model responds: only the user message, and no usage reported yet.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // `used_tokens` is expected to be input + output tokens. The 1_000_000
        // `max_tokens` presumably comes from the fake model's limit — confirm in
        // FakeLanguageModel.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncating at the first message removes it and its reply, and clears token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The new user message is checked before the executor is driven.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the new exchange's numbers.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2785
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message restores exactly the state after the
    // first exchange, both the markdown transcript and the reported token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: runs to completion with usage 32k in / 16k out.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable check for the post-first-exchange state; run now and again after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                    input_tokens: 32_000,
                    output_tokens: 16_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; its id is kept so the thread can be truncated at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both exchanges are present, and usage reflects the latest response.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });

    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    // Truncation rolls back both the transcript and the token usage.
    assert_first_message_state(cx);
}
2898
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // The summarization model produces a title after the first exchange only; a later
    // send must not start another title completion.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model for summarization, so its completions are driven
    // independently of the main model.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The title is still the default until the summary model's stream is completed.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks spell "Hello world\nGoodnight Moon"; the assertion below expects
    // only the text before the first newline to become the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No completion may be pending on the summary model after the second exchange.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2944
2945#[gpui::test]
2946async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2947 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2948 let fake_model = model.as_fake();
2949
2950 let _events = thread
2951 .update(cx, |thread, cx| {
2952 thread.add_tool(ToolRequiringPermission, None);
2953 thread.add_tool(EchoTool, None);
2954 thread.send(UserMessageId::new(), ["Hey!"], cx)
2955 })
2956 .unwrap();
2957 cx.run_until_parked();
2958
2959 let permission_tool_use = LanguageModelToolUse {
2960 id: "tool_id_1".into(),
2961 name: ToolRequiringPermission::NAME.into(),
2962 raw_input: "{}".into(),
2963 input: json!({}),
2964 is_input_complete: true,
2965 thought_signature: None,
2966 };
2967 let echo_tool_use = LanguageModelToolUse {
2968 id: "tool_id_2".into(),
2969 name: EchoTool::NAME.into(),
2970 raw_input: json!({"text": "test"}).to_string(),
2971 input: json!({"text": "test"}),
2972 is_input_complete: true,
2973 thought_signature: None,
2974 };
2975 fake_model.send_last_completion_stream_text_chunk("Hi!");
2976 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2977 permission_tool_use,
2978 ));
2979 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2980 echo_tool_use.clone(),
2981 ));
2982 fake_model.end_last_completion_stream();
2983 cx.run_until_parked();
2984
2985 // Ensure pending tools are skipped when building a request.
2986 let request = thread
2987 .read_with(cx, |thread, cx| {
2988 thread.build_completion_request(CompletionIntent::EditFile, cx)
2989 })
2990 .unwrap();
2991 assert_eq!(
2992 request.messages[1..],
2993 vec![
2994 LanguageModelRequestMessage {
2995 role: Role::User,
2996 content: vec!["Hey!".into()],
2997 cache: true,
2998 reasoning_details: None,
2999 },
3000 LanguageModelRequestMessage {
3001 role: Role::Assistant,
3002 content: vec![
3003 MessageContent::Text("Hi!".into()),
3004 MessageContent::ToolUse(echo_tool_use.clone())
3005 ],
3006 cache: false,
3007 reasoning_details: None,
3008 },
3009 LanguageModelRequestMessage {
3010 role: Role::User,
3011 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3012 tool_use_id: echo_tool_use.id.clone(),
3013 tool_name: echo_tool_use.name,
3014 is_error: false,
3015 content: "test".into(),
3016 output: Some("test".into())
3017 })],
3018 cache: false,
3019 reasoning_details: None,
3020 },
3021 ],
3022 );
3023}
3024
3025#[gpui::test]
3026async fn test_agent_connection(cx: &mut TestAppContext) {
3027 cx.update(settings::init);
3028 let templates = Templates::new();
3029
3030 // Initialize language model system with test provider
3031 cx.update(|cx| {
3032 gpui_tokio::init(cx);
3033
3034 let http_client = FakeHttpClient::with_404_response();
3035 let clock = Arc::new(clock::FakeSystemClock::new());
3036 let client = Client::new(clock, http_client, cx);
3037 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3038 language_model::init(client.clone(), cx);
3039 language_models::init(user_store, client.clone(), cx);
3040 LanguageModelRegistry::test(cx);
3041 });
3042 cx.executor().forbid_parking();
3043
3044 // Create a project for new_thread
3045 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3046 fake_fs.insert_tree(path!("/test"), json!({})).await;
3047 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3048 let cwd = Path::new("/test");
3049 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3050
3051 // Create agent and connection
3052 let agent = NativeAgent::new(
3053 project.clone(),
3054 thread_store,
3055 templates.clone(),
3056 None,
3057 fake_fs.clone(),
3058 &mut cx.to_async(),
3059 )
3060 .await
3061 .unwrap();
3062 let connection = NativeAgentConnection(agent.clone());
3063
3064 // Create a thread using new_thread
3065 let connection_rc = Rc::new(connection.clone());
3066 let acp_thread = cx
3067 .update(|cx| connection_rc.new_session(project, cwd, cx))
3068 .await
3069 .expect("new_thread should succeed");
3070
3071 // Get the session_id from the AcpThread
3072 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3073
3074 // Test model_selector returns Some
3075 let selector_opt = connection.model_selector(&session_id);
3076 assert!(
3077 selector_opt.is_some(),
3078 "agent should always support ModelSelector"
3079 );
3080 let selector = selector_opt.unwrap();
3081
3082 // Test list_models
3083 let listed_models = cx
3084 .update(|cx| selector.list_models(cx))
3085 .await
3086 .expect("list_models should succeed");
3087 let AgentModelList::Grouped(listed_models) = listed_models else {
3088 panic!("Unexpected model list type");
3089 };
3090 assert!(!listed_models.is_empty(), "should have at least one model");
3091 assert_eq!(
3092 listed_models[&AgentModelGroupName("Fake".into())][0]
3093 .id
3094 .0
3095 .as_ref(),
3096 "fake/fake"
3097 );
3098
3099 // Test selected_model returns the default
3100 let model = cx
3101 .update(|cx| selector.selected_model(cx))
3102 .await
3103 .expect("selected_model should succeed");
3104 let model = cx
3105 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3106 .unwrap();
3107 let model = model.as_fake();
3108 assert_eq!(model.id().0, "fake", "should return default model");
3109
3110 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3111 cx.run_until_parked();
3112 model.send_last_completion_stream_text_chunk("def");
3113 cx.run_until_parked();
3114 acp_thread.read_with(cx, |thread, cx| {
3115 assert_eq!(
3116 thread.to_markdown(cx),
3117 indoc! {"
3118 ## User
3119
3120 abc
3121
3122 ## Assistant
3123
3124 def
3125
3126 "}
3127 )
3128 });
3129
3130 // Test cancel
3131 cx.update(|cx| connection.cancel(&session_id, cx));
3132 request.await.expect("prompt should fail gracefully");
3133
3134 // Explicitly close the session and drop the ACP thread.
3135 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3136 .await
3137 .unwrap();
3138 drop(acp_thread);
3139 let result = cx
3140 .update(|cx| {
3141 connection.prompt(
3142 Some(acp_thread::UserMessageId::new()),
3143 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3144 cx,
3145 )
3146 })
3147 .await;
3148 assert_eq!(
3149 result.as_ref().unwrap_err().to_string(),
3150 "Session not found",
3151 "unexpected result: {:?}",
3152 result
3153 );
3154}
3155
3156#[gpui::test]
3157async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3158 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3159 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3160 let fake_model = model.as_fake();
3161
3162 let mut events = thread
3163 .update(cx, |thread, cx| {
3164 thread.send(UserMessageId::new(), ["Echo something"], cx)
3165 })
3166 .unwrap();
3167 cx.run_until_parked();
3168
3169 // Simulate streaming partial input.
3170 let input = json!({});
3171 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3172 LanguageModelToolUse {
3173 id: "1".into(),
3174 name: EchoTool::NAME.into(),
3175 raw_input: input.to_string(),
3176 input,
3177 is_input_complete: false,
3178 thought_signature: None,
3179 },
3180 ));
3181
3182 // Input streaming completed
3183 let input = json!({ "text": "Hello!" });
3184 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3185 LanguageModelToolUse {
3186 id: "1".into(),
3187 name: "echo".into(),
3188 raw_input: input.to_string(),
3189 input,
3190 is_input_complete: true,
3191 thought_signature: None,
3192 },
3193 ));
3194 fake_model.end_last_completion_stream();
3195 cx.run_until_parked();
3196
3197 let tool_call = expect_tool_call(&mut events).await;
3198 assert_eq!(
3199 tool_call,
3200 acp::ToolCall::new("1", "Echo")
3201 .raw_input(json!({}))
3202 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3203 );
3204 let update = expect_tool_call_update_fields(&mut events).await;
3205 assert_eq!(
3206 update,
3207 acp::ToolCallUpdate::new(
3208 "1",
3209 acp::ToolCallUpdateFields::new()
3210 .title("Echo")
3211 .kind(acp::ToolKind::Other)
3212 .raw_input(json!({ "text": "Hello!"}))
3213 )
3214 );
3215 let update = expect_tool_call_update_fields(&mut events).await;
3216 assert_eq!(
3217 update,
3218 acp::ToolCallUpdate::new(
3219 "1",
3220 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3221 )
3222 );
3223 let update = expect_tool_call_update_fields(&mut events).await;
3224 assert_eq!(
3225 update,
3226 acp::ToolCallUpdate::new(
3227 "1",
3228 acp::ToolCallUpdateFields::new()
3229 .status(acp::ToolCallStatus::Completed)
3230 .raw_output("Hello!")
3231 )
3232 );
3233}
3234
3235#[gpui::test]
3236async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3237 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3238 let fake_model = model.as_fake();
3239
3240 let mut events = thread
3241 .update(cx, |thread, cx| {
3242 thread.send(UserMessageId::new(), ["Hello!"], cx)
3243 })
3244 .unwrap();
3245 cx.run_until_parked();
3246
3247 fake_model.send_last_completion_stream_text_chunk("Hey!");
3248 fake_model.end_last_completion_stream();
3249
3250 let mut retry_events = Vec::new();
3251 while let Some(Ok(event)) = events.next().await {
3252 match event {
3253 ThreadEvent::Retry(retry_status) => {
3254 retry_events.push(retry_status);
3255 }
3256 ThreadEvent::Stop(..) => break,
3257 _ => {}
3258 }
3259 }
3260
3261 assert_eq!(retry_events.len(), 0);
3262 thread.read_with(cx, |thread, _cx| {
3263 assert_eq!(
3264 thread.to_markdown(),
3265 indoc! {"
3266 ## User
3267
3268 Hello!
3269
3270 ## Assistant
3271
3272 Hey!
3273 "}
3274 )
3275 });
3276}
3277
3278#[gpui::test]
3279async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3280 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3281 let fake_model = model.as_fake();
3282
3283 let mut events = thread
3284 .update(cx, |thread, cx| {
3285 thread.send(UserMessageId::new(), ["Hello!"], cx)
3286 })
3287 .unwrap();
3288 cx.run_until_parked();
3289
3290 fake_model.send_last_completion_stream_text_chunk("Hey,");
3291 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3292 provider: LanguageModelProviderName::new("Anthropic"),
3293 retry_after: Some(Duration::from_secs(3)),
3294 });
3295 fake_model.end_last_completion_stream();
3296
3297 cx.executor().advance_clock(Duration::from_secs(3));
3298 cx.run_until_parked();
3299
3300 fake_model.send_last_completion_stream_text_chunk("there!");
3301 fake_model.end_last_completion_stream();
3302 cx.run_until_parked();
3303
3304 let mut retry_events = Vec::new();
3305 while let Some(Ok(event)) = events.next().await {
3306 match event {
3307 ThreadEvent::Retry(retry_status) => {
3308 retry_events.push(retry_status);
3309 }
3310 ThreadEvent::Stop(..) => break,
3311 _ => {}
3312 }
3313 }
3314
3315 assert_eq!(retry_events.len(), 1);
3316 assert!(matches!(
3317 retry_events[0],
3318 acp_thread::RetryStatus { attempt: 1, .. }
3319 ));
3320 thread.read_with(cx, |thread, _cx| {
3321 assert_eq!(
3322 thread.to_markdown(),
3323 indoc! {"
3324 ## User
3325
3326 Hello!
3327
3328 ## Assistant
3329
3330 Hey,
3331
3332 [resume]
3333
3334 ## Assistant
3335
3336 there!
3337 "}
3338 )
3339 });
3340}
3341
/// When a completion fails after the model has already emitted a complete tool use,
/// the tool is still run, and the retried request carries both the tool use and its
/// result before the thread finishes normally.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // A fully-streamed echo tool call, immediately followed by a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the back-off, the retried request should contain the tool result.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // Skips messages[0] (presumably the system prompt — not asserted here).
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                // Only the final message carries `cache: true` in this request.
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // The retried completion finishes with plain text; the event stream then ends.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3421
3422#[gpui::test]
3423async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3424 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3425 let fake_model = model.as_fake();
3426
3427 let mut events = thread
3428 .update(cx, |thread, cx| {
3429 thread.send(UserMessageId::new(), ["Hello!"], cx)
3430 })
3431 .unwrap();
3432 cx.run_until_parked();
3433
3434 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3435 fake_model.send_last_completion_stream_error(
3436 LanguageModelCompletionError::ServerOverloaded {
3437 provider: LanguageModelProviderName::new("Anthropic"),
3438 retry_after: Some(Duration::from_secs(3)),
3439 },
3440 );
3441 fake_model.end_last_completion_stream();
3442 cx.executor().advance_clock(Duration::from_secs(3));
3443 cx.run_until_parked();
3444 }
3445
3446 let mut errors = Vec::new();
3447 let mut retry_events = Vec::new();
3448 while let Some(event) = events.next().await {
3449 match event {
3450 Ok(ThreadEvent::Retry(retry_status)) => {
3451 retry_events.push(retry_status);
3452 }
3453 Ok(ThreadEvent::Stop(..)) => break,
3454 Err(error) => errors.push(error),
3455 _ => {}
3456 }
3457 }
3458
3459 assert_eq!(
3460 retry_events.len(),
3461 crate::thread::MAX_RETRY_ATTEMPTS as usize
3462 );
3463 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3464 assert_eq!(retry_events[i].attempt, i + 1);
3465 }
3466 assert_eq!(errors.len(), 1);
3467 let error = errors[0]
3468 .downcast_ref::<LanguageModelCompletionError>()
3469 .unwrap();
3470 assert!(matches!(
3471 error,
3472 LanguageModelCompletionError::ServerOverloaded { .. }
3473 ));
3474}
3475
3476/// Filters out the stop events for asserting against in tests
3477fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3478 result_events
3479 .into_iter()
3480 .filter_map(|event| match event.unwrap() {
3481 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3482 _ => None,
3483 })
3484 .collect()
3485}
3486
/// Handles returned by `setup` for driving a [`Thread`] in tests.
struct ThreadTest {
    /// The model the thread was created with. For `TestModel::Fake` this is a
    /// `FakeLanguageModel`, which tests reach via `as_fake()`.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context that was passed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    /// The test project's context-server store (e.g. for `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing both the project and the watched user settings file.
    fs: Arc<FakeFs>,
}
3494
3495enum TestModel {
3496 Sonnet4,
3497 Fake,
3498}
3499
3500impl TestModel {
3501 fn id(&self) -> LanguageModelId {
3502 match self {
3503 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3504 TestModel::Fake => unreachable!(),
3505 }
3506 }
3507}
3508
/// Builds a [`ThreadTest`], in three steps:
/// - a fake filesystem with a user settings file whose default agent profile
///   (`test-profile`) enables the test tools;
/// - a test project rooted at `/test`;
/// - a [`Thread`] bound to the requested model.
///
/// For [`TestModel::Sonnet4`] this installs a real HTTP client and a production
/// `Client`, then authenticates the model's provider, so it needs network access
/// (and presumably valid credentials). For [`TestModel::Fake`] a fresh
/// `FakeLanguageModel` is used.
///
/// # Panics
/// Panics if the settings directory cannot be created. For real models, it also
/// panics if the model is missing from the registry or authentication fails.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably required because real-model tests block on external
    // I/O; confirm whether fake-model tests still need it.
    cx.executor().allow_parking();

    // Seed the user settings file with a profile that enables every test tool.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real models need a real HTTP stack and the language model providers.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Apply the settings file written above (and any later edits to it).
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake one directly, a real one via the registry after
    // its provider finishes authenticating.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3611
/// Installs `env_logger` once, before any test runs, but only when the developer has
/// set `RUST_LOG`. Otherwise test output stays quiet.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3619
3620fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3621 let fs = fs.clone();
3622 cx.spawn({
3623 async move |cx| {
3624 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3625 cx.background_executor(),
3626 fs,
3627 paths::settings_file().clone(),
3628 );
3629 let _watcher_task = watcher_task;
3630
3631 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3632 cx.update(|cx| {
3633 SettingsStore::update_global(cx, |settings, cx| {
3634 settings.set_user_settings(&new_settings_content, cx)
3635 })
3636 })
3637 .ok();
3638 }
3639 }
3640 })
3641 .detach();
3642}
3643
3644fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3645 completion
3646 .tools
3647 .iter()
3648 .map(|tool| tool.name.clone())
3649 .collect()
3650}
3651
/// Registers a stdio context server named `name` in the global project settings and
/// starts it in `context_server_store`, backed by a fake in-process transport that
/// advertises `tools`.
///
/// Returns a receiver yielding every `CallTool` request the fake server receives,
/// each paired with a oneshot sender. The test must use that sender to answer: the
/// fake server awaits it before replying, and panics if it is dropped.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in settings. The command below is a placeholder: the
    // server is started further down with `fake_transport` (presumably the command
    // is never executed — NOTE(review): confirm).
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support with `list_changed` enabled.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Always report the fixed `tools` list, with no pagination.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and wait for its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3731
/// `tokens_before_message` reports, for each user message, the input-token count of
/// the request that preceded it. The first message has nothing before it, and the
/// values for earlier messages don't change as the conversation grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3838
/// After truncating at a message, `tokens_before_message` returns `None` for the
/// removed message, while the surviving first message is unaffected (still `None`).
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // (only two user messages are actually sent below).
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3909
3910#[gpui::test]
3911async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3912 init_test(cx);
3913
3914 let fs = FakeFs::new(cx.executor());
3915 fs.insert_tree("/root", json!({})).await;
3916 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3917
3918 // Test 1: Deny rule blocks command
3919 {
3920 let environment = Rc::new(cx.update(|cx| {
3921 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3922 }));
3923
3924 cx.update(|cx| {
3925 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3926 settings.tool_permissions.tools.insert(
3927 TerminalTool::NAME.into(),
3928 agent_settings::ToolRules {
3929 default: Some(settings::ToolPermissionMode::Confirm),
3930 always_allow: vec![],
3931 always_deny: vec![
3932 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3933 ],
3934 always_confirm: vec![],
3935 invalid_patterns: vec![],
3936 },
3937 );
3938 agent_settings::AgentSettings::override_global(settings, cx);
3939 });
3940
3941 #[allow(clippy::arc_with_non_send_sync)]
3942 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3943 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3944
3945 let task = cx.update(|cx| {
3946 tool.run(
3947 crate::TerminalToolInput {
3948 command: "rm -rf /".to_string(),
3949 cd: ".".to_string(),
3950 timeout_ms: None,
3951 },
3952 event_stream,
3953 cx,
3954 )
3955 });
3956
3957 let result = task.await;
3958 assert!(
3959 result.is_err(),
3960 "expected command to be blocked by deny rule"
3961 );
3962 let err_msg = result.unwrap_err().to_string().to_lowercase();
3963 assert!(
3964 err_msg.contains("blocked"),
3965 "error should mention the command was blocked"
3966 );
3967 }
3968
3969 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
3970 {
3971 let environment = Rc::new(cx.update(|cx| {
3972 FakeThreadEnvironment::default()
3973 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
3974 }));
3975
3976 cx.update(|cx| {
3977 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3978 settings.tool_permissions.tools.insert(
3979 TerminalTool::NAME.into(),
3980 agent_settings::ToolRules {
3981 default: Some(settings::ToolPermissionMode::Deny),
3982 always_allow: vec![
3983 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3984 ],
3985 always_deny: vec![],
3986 always_confirm: vec![],
3987 invalid_patterns: vec![],
3988 },
3989 );
3990 agent_settings::AgentSettings::override_global(settings, cx);
3991 });
3992
3993 #[allow(clippy::arc_with_non_send_sync)]
3994 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3995 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3996
3997 let task = cx.update(|cx| {
3998 tool.run(
3999 crate::TerminalToolInput {
4000 command: "echo hello".to_string(),
4001 cd: ".".to_string(),
4002 timeout_ms: None,
4003 },
4004 event_stream,
4005 cx,
4006 )
4007 });
4008
4009 let update = rx.expect_update_fields().await;
4010 assert!(
4011 update.content.iter().any(|blocks| {
4012 blocks
4013 .iter()
4014 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4015 }),
4016 "expected terminal content (allow rule should skip confirmation and override default deny)"
4017 );
4018
4019 let result = task.await;
4020 assert!(
4021 result.is_ok(),
4022 "expected command to succeed without confirmation"
4023 );
4024 }
4025
4026 // Test 3: global default: allow does NOT override always_confirm patterns
4027 {
4028 let environment = Rc::new(cx.update(|cx| {
4029 FakeThreadEnvironment::default()
4030 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4031 }));
4032
4033 cx.update(|cx| {
4034 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4035 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4036 settings.tool_permissions.tools.insert(
4037 TerminalTool::NAME.into(),
4038 agent_settings::ToolRules {
4039 default: Some(settings::ToolPermissionMode::Allow),
4040 always_allow: vec![],
4041 always_deny: vec![],
4042 always_confirm: vec![
4043 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4044 ],
4045 invalid_patterns: vec![],
4046 },
4047 );
4048 agent_settings::AgentSettings::override_global(settings, cx);
4049 });
4050
4051 #[allow(clippy::arc_with_non_send_sync)]
4052 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4053 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4054
4055 let _task = cx.update(|cx| {
4056 tool.run(
4057 crate::TerminalToolInput {
4058 command: "sudo rm file".to_string(),
4059 cd: ".".to_string(),
4060 timeout_ms: None,
4061 },
4062 event_stream,
4063 cx,
4064 )
4065 });
4066
4067 // With global default: allow, confirm patterns are still respected
4068 // The expect_authorization() call will panic if no authorization is requested,
4069 // which validates that the confirm pattern still triggers confirmation
4070 let _auth = rx.expect_authorization().await;
4071
4072 drop(_task);
4073 }
4074
4075 // Test 4: tool-specific default: deny is respected even with global default: allow
4076 {
4077 let environment = Rc::new(cx.update(|cx| {
4078 FakeThreadEnvironment::default()
4079 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4080 }));
4081
4082 cx.update(|cx| {
4083 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4084 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4085 settings.tool_permissions.tools.insert(
4086 TerminalTool::NAME.into(),
4087 agent_settings::ToolRules {
4088 default: Some(settings::ToolPermissionMode::Deny),
4089 always_allow: vec![],
4090 always_deny: vec![],
4091 always_confirm: vec![],
4092 invalid_patterns: vec![],
4093 },
4094 );
4095 agent_settings::AgentSettings::override_global(settings, cx);
4096 });
4097
4098 #[allow(clippy::arc_with_non_send_sync)]
4099 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4100 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4101
4102 let task = cx.update(|cx| {
4103 tool.run(
4104 crate::TerminalToolInput {
4105 command: "echo hello".to_string(),
4106 cd: ".".to_string(),
4107 timeout_ms: None,
4108 },
4109 event_stream,
4110 cx,
4111 )
4112 });
4113
4114 // tool-specific default: deny is respected even with global default: allow
4115 let result = task.await;
4116 assert!(
4117 result.is_err(),
4118 "expected command to be blocked by tool-specific deny default"
4119 );
4120 let err_msg = result.unwrap_err().to_string().to_lowercase();
4121 assert!(
4122 err_msg.contains("disabled"),
4123 "error should mention the tool is disabled, got: {err_msg}"
4124 );
4125 }
4126}
4127
4128#[gpui::test]
4129async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4130 init_test(cx);
4131
4132 cx.update(|cx| {
4133 cx.update_flags(true, vec!["subagents".to_string()]);
4134 });
4135
4136 let fs = FakeFs::new(cx.executor());
4137 fs.insert_tree(path!("/test"), json!({})).await;
4138 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4139 let project_context = cx.new(|_cx| ProjectContext::default());
4140 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4141 let context_server_registry =
4142 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4143 let model = Arc::new(FakeLanguageModel::default());
4144
4145 let environment = Rc::new(cx.update(|cx| {
4146 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4147 }));
4148
4149 let thread = cx.new(|cx| {
4150 let mut thread = Thread::new(
4151 project.clone(),
4152 project_context,
4153 context_server_registry,
4154 Templates::new(),
4155 Some(model),
4156 cx,
4157 );
4158 thread.add_default_tools(None, environment, cx);
4159 thread
4160 });
4161
4162 thread.read_with(cx, |thread, _| {
4163 assert!(
4164 thread.has_registered_tool(SubagentTool::NAME),
4165 "subagent tool should be present when feature flag is enabled"
4166 );
4167 });
4168}
4169
4170#[gpui::test]
4171async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4172 init_test(cx);
4173
4174 cx.update(|cx| {
4175 cx.update_flags(true, vec!["subagents".to_string()]);
4176 });
4177
4178 let fs = FakeFs::new(cx.executor());
4179 fs.insert_tree(path!("/test"), json!({})).await;
4180 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4181 let project_context = cx.new(|_cx| ProjectContext::default());
4182 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4183 let context_server_registry =
4184 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4185 let model = Arc::new(FakeLanguageModel::default());
4186
4187 let parent_thread = cx.new(|cx| {
4188 Thread::new(
4189 project.clone(),
4190 project_context,
4191 context_server_registry,
4192 Templates::new(),
4193 Some(model.clone()),
4194 cx,
4195 )
4196 });
4197
4198 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4199 subagent_thread.read_with(cx, |subagent_thread, cx| {
4200 assert!(subagent_thread.is_subagent());
4201 assert_eq!(subagent_thread.depth(), 1);
4202 assert_eq!(
4203 subagent_thread.model().map(|model| model.id()),
4204 Some(model.id())
4205 );
4206 assert_eq!(
4207 subagent_thread.parent_thread_id(),
4208 Some(parent_thread.read(cx).id().clone())
4209 );
4210 });
4211}
4212
4213#[gpui::test]
4214async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4215 init_test(cx);
4216
4217 cx.update(|cx| {
4218 cx.update_flags(true, vec!["subagents".to_string()]);
4219 });
4220
4221 let fs = FakeFs::new(cx.executor());
4222 fs.insert_tree(path!("/test"), json!({})).await;
4223 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4224 let project_context = cx.new(|_cx| ProjectContext::default());
4225 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4226 let context_server_registry =
4227 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4228 let model = Arc::new(FakeLanguageModel::default());
4229 let environment = Rc::new(cx.update(|cx| {
4230 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4231 }));
4232
4233 let deep_parent_thread = cx.new(|cx| {
4234 let mut thread = Thread::new(
4235 project.clone(),
4236 project_context,
4237 context_server_registry,
4238 Templates::new(),
4239 Some(model.clone()),
4240 cx,
4241 );
4242 thread.set_subagent_context(SubagentContext {
4243 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4244 depth: MAX_SUBAGENT_DEPTH - 1,
4245 });
4246 thread
4247 });
4248 let deep_subagent_thread = cx.new(|cx| {
4249 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4250 thread.add_default_tools(None, environment, cx);
4251 thread
4252 });
4253
4254 deep_subagent_thread.read_with(cx, |thread, _| {
4255 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4256 assert!(
4257 !thread.has_registered_tool(SubagentTool::NAME),
4258 "subagent tool should not be present at max depth"
4259 );
4260 });
4261}
4262
4263#[gpui::test]
4264async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4265 init_test(cx);
4266
4267 cx.update(|cx| {
4268 cx.update_flags(true, vec!["subagents".to_string()]);
4269 });
4270
4271 let fs = FakeFs::new(cx.executor());
4272 fs.insert_tree(path!("/test"), json!({})).await;
4273 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4274 let project_context = cx.new(|_cx| ProjectContext::default());
4275 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4276 let context_server_registry =
4277 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4278 let model = Arc::new(FakeLanguageModel::default());
4279
4280 let parent = cx.new(|cx| {
4281 Thread::new(
4282 project.clone(),
4283 project_context.clone(),
4284 context_server_registry.clone(),
4285 Templates::new(),
4286 Some(model.clone()),
4287 cx,
4288 )
4289 });
4290
4291 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4292
4293 parent.update(cx, |thread, _cx| {
4294 thread.register_running_subagent(subagent.downgrade());
4295 });
4296
4297 subagent
4298 .update(cx, |thread, cx| {
4299 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4300 })
4301 .unwrap();
4302 cx.run_until_parked();
4303
4304 subagent.read_with(cx, |thread, _| {
4305 assert!(!thread.is_turn_complete(), "subagent should be running");
4306 });
4307
4308 parent.update(cx, |thread, cx| {
4309 thread.cancel(cx).detach();
4310 });
4311
4312 subagent.read_with(cx, |thread, _| {
4313 assert!(
4314 thread.is_turn_complete(),
4315 "subagent should be cancelled when parent cancels"
4316 );
4317 });
4318}
4319
/// Checks that `SubagentTool::run` resolves with a "cancelled by user" error once
/// cancellation is signalled through its `ToolCallEventStream`, even while the
/// (fake) subagent it launched is still outstanding.
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()`: the tool's task must resolve with a
    // "cancelled by user" error. It does not directly assert that the fake
    // subagent itself was stopped.
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    // The environment hands out a fake subagent handle which, per its name, never
    // completes, so only cancellation can end the tool's task.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
    }));

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));

    // `_rx` is bound (not `_`) so the event receiver stays alive for the whole test.
    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                label: "Long running task".to_string(),
                task_prompt: "Do a very long task that takes forever".to_string(),
                summary_prompt: "Summarize".to_string(),
                timeout_ms: None,
                allowed_tools: None,
            },
            event_stream.clone(),
            cx,
        )
    });

    // Let the tool reach its waiting state before cancelling.
    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error.
    // NOTE(review): on gpui's deterministic test executor, timers may only fire when
    // the clock is advanced, so a hang here could surface as a parked-executor panic
    // rather than the panic message below. Confirm.
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4397
4398#[gpui::test]
4399async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4400 init_test(cx);
4401 always_allow_tools(cx);
4402
4403 cx.update(|cx| {
4404 cx.update_flags(true, vec!["subagents".to_string()]);
4405 });
4406
4407 let fs = FakeFs::new(cx.executor());
4408 fs.insert_tree(path!("/test"), json!({})).await;
4409 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4410 let project_context = cx.new(|_cx| ProjectContext::default());
4411 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4412 let context_server_registry =
4413 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4414 cx.update(LanguageModelRegistry::test);
4415 let model = Arc::new(FakeLanguageModel::default());
4416 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4417 let native_agent = NativeAgent::new(
4418 project.clone(),
4419 thread_store,
4420 Templates::new(),
4421 None,
4422 fs,
4423 &mut cx.to_async(),
4424 )
4425 .await
4426 .unwrap();
4427 let parent_thread = cx.new(|cx| {
4428 Thread::new(
4429 project.clone(),
4430 project_context,
4431 context_server_registry,
4432 Templates::new(),
4433 Some(model.clone()),
4434 cx,
4435 )
4436 });
4437
4438 let mut handles = Vec::new();
4439 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4440 let handle = cx
4441 .update(|cx| {
4442 NativeThreadEnvironment::create_subagent_thread(
4443 native_agent.downgrade(),
4444 parent_thread.clone(),
4445 "some title".to_string(),
4446 "some task".to_string(),
4447 None,
4448 None,
4449 cx,
4450 )
4451 })
4452 .expect("Expected to be able to create subagent thread");
4453 handles.push(handle);
4454 }
4455
4456 let result = cx.update(|cx| {
4457 NativeThreadEnvironment::create_subagent_thread(
4458 native_agent.downgrade(),
4459 parent_thread.clone(),
4460 "some title".to_string(),
4461 "some task".to_string(),
4462 None,
4463 None,
4464 cx,
4465 )
4466 });
4467 assert!(result.is_err());
4468 assert_eq!(
4469 result.err().unwrap().to_string(),
4470 format!(
4471 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4472 MAX_PARALLEL_SUBAGENTS
4473 )
4474 );
4475}
4476
/// End-to-end check of a subagent's two model turns. The first request carries
/// the system prompt plus the task prompt. After the model answers, a second
/// request carries the summary prompt. The text streamed for that second turn
/// is what `wait_for_summary` resolves to.
#[gpui::test]
async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
    init_test(cx);

    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    cx.update(LanguageModelRegistry::test);
    let model = Arc::new(FakeLanguageModel::default());
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let native_agent = NativeAgent::new(
        project.clone(),
        thread_store,
        Templates::new(),
        None,
        fs,
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let parent_thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Subagent with a 10ms timeout. The fake clock is never advanced in this test.
    let subagent_handle = cx
        .update(|cx| {
            NativeThreadEnvironment::create_subagent_thread(
                native_agent.downgrade(),
                parent_thread.clone(),
                "some title".to_string(),
                "task prompt".to_string(),
                Some(Duration::from_millis(10)),
                None,
                cx,
            )
        })
        .expect("Failed to create subagent");

    let summary_task =
        subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());

    cx.run_until_parked();

    // First model turn: inspect the request the subagent sent for its task.
    {
        let messages = model.pending_completions().last().unwrap().messages.clone();
        // Ensure that model received a system prompt
        assert_eq!(messages[0].role, Role::System);
        // Ensure that model received a task prompt
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(
            messages[1].content,
            vec![MessageContent::Text("task prompt".to_string())]
        );
    }

    // Finish the task turn so the subagent moves on to the summary request.
    model.send_last_completion_stream_text_chunk("Some task response...");
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Second model turn: history now includes the task response plus the summary prompt.
    {
        let messages = model.pending_completions().last().unwrap().messages.clone();
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(
            messages[2].content,
            vec![MessageContent::Text("Some task response...".to_string())]
        );
        // Ensure that model received a summary prompt
        assert_eq!(messages[3].role, Role::User);
        assert_eq!(
            messages[3].content,
            vec![MessageContent::Text("summary prompt".to_string())]
        );
    }

    model.send_last_completion_stream_text_chunk("Some summary...");
    model.end_last_completion_stream();

    // The summary text is returned with a trailing newline appended.
    let result = summary_task.await;
    assert_eq!(result.unwrap(), "Some summary...\n");
}
4575
4576#[gpui::test]
4577async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4578 cx: &mut TestAppContext,
4579) {
4580 init_test(cx);
4581
4582 always_allow_tools(cx);
4583
4584 cx.update(|cx| {
4585 cx.update_flags(true, vec!["subagents".to_string()]);
4586 });
4587
4588 let fs = FakeFs::new(cx.executor());
4589 fs.insert_tree(path!("/test"), json!({})).await;
4590 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4591 let project_context = cx.new(|_cx| ProjectContext::default());
4592 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4593 let context_server_registry =
4594 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4595 cx.update(LanguageModelRegistry::test);
4596 let model = Arc::new(FakeLanguageModel::default());
4597 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4598 let native_agent = NativeAgent::new(
4599 project.clone(),
4600 thread_store,
4601 Templates::new(),
4602 None,
4603 fs,
4604 &mut cx.to_async(),
4605 )
4606 .await
4607 .unwrap();
4608 let parent_thread = cx.new(|cx| {
4609 Thread::new(
4610 project.clone(),
4611 project_context,
4612 context_server_registry,
4613 Templates::new(),
4614 Some(model.clone()),
4615 cx,
4616 )
4617 });
4618
4619 let subagent_handle = cx
4620 .update(|cx| {
4621 NativeThreadEnvironment::create_subagent_thread(
4622 native_agent.downgrade(),
4623 parent_thread.clone(),
4624 "some title".to_string(),
4625 "task prompt".to_string(),
4626 Some(Duration::from_millis(100)),
4627 None,
4628 cx,
4629 )
4630 })
4631 .expect("Failed to create subagent");
4632
4633 let summary_task =
4634 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4635
4636 cx.run_until_parked();
4637
4638 {
4639 let messages = model.pending_completions().last().unwrap().messages.clone();
4640 // Ensure that model received a system prompt
4641 assert_eq!(messages[0].role, Role::System);
4642 // Ensure that model received a task prompt
4643 assert_eq!(
4644 messages[1].content,
4645 vec![MessageContent::Text("task prompt".to_string())]
4646 );
4647 }
4648
4649 // Don't complete the initial model stream — let the timeout expire instead.
4650 cx.executor().advance_clock(Duration::from_millis(200));
4651 cx.run_until_parked();
4652
4653 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4654 // instead of the summary_prompt.
4655 {
4656 let messages = model.pending_completions().last().unwrap().messages.clone();
4657 let last_user_message = messages
4658 .iter()
4659 .rev()
4660 .find(|m| m.role == Role::User)
4661 .unwrap();
4662 assert_eq!(
4663 last_user_message.content,
4664 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
4665 );
4666 }
4667
4668 model.send_last_completion_stream_text_chunk("Some context low response...");
4669 model.end_last_completion_stream();
4670
4671 let result = summary_task.await;
4672 assert_eq!(result.unwrap(), "Some context low response...\n");
4673}
4674
4675#[gpui::test]
4676async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4677 init_test(cx);
4678
4679 always_allow_tools(cx);
4680
4681 cx.update(|cx| {
4682 cx.update_flags(true, vec!["subagents".to_string()]);
4683 });
4684
4685 let fs = FakeFs::new(cx.executor());
4686 fs.insert_tree(path!("/test"), json!({})).await;
4687 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4688 let project_context = cx.new(|_cx| ProjectContext::default());
4689 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4690 let context_server_registry =
4691 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4692 cx.update(LanguageModelRegistry::test);
4693 let model = Arc::new(FakeLanguageModel::default());
4694 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4695 let native_agent = NativeAgent::new(
4696 project.clone(),
4697 thread_store,
4698 Templates::new(),
4699 None,
4700 fs,
4701 &mut cx.to_async(),
4702 )
4703 .await
4704 .unwrap();
4705 let parent_thread = cx.new(|cx| {
4706 let mut thread = Thread::new(
4707 project.clone(),
4708 project_context,
4709 context_server_registry,
4710 Templates::new(),
4711 Some(model.clone()),
4712 cx,
4713 );
4714 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4715 thread.add_tool(GrepTool::new(project.clone()), None);
4716 thread
4717 });
4718
4719 let _subagent_handle = cx
4720 .update(|cx| {
4721 NativeThreadEnvironment::create_subagent_thread(
4722 native_agent.downgrade(),
4723 parent_thread.clone(),
4724 "some title".to_string(),
4725 "task prompt".to_string(),
4726 Some(Duration::from_millis(10)),
4727 None,
4728 cx,
4729 )
4730 })
4731 .expect("Failed to create subagent");
4732
4733 cx.run_until_parked();
4734
4735 let tools = model
4736 .pending_completions()
4737 .last()
4738 .unwrap()
4739 .tools
4740 .iter()
4741 .map(|tool| tool.name.clone())
4742 .collect::<Vec<_>>();
4743 assert_eq!(tools.len(), 2);
4744 assert!(tools.contains(&"grep".to_string()));
4745 assert!(tools.contains(&"list_directory".to_string()));
4746}
4747
4748#[gpui::test]
4749async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4750 init_test(cx);
4751
4752 always_allow_tools(cx);
4753
4754 cx.update(|cx| {
4755 cx.update_flags(true, vec!["subagents".to_string()]);
4756 });
4757
4758 let fs = FakeFs::new(cx.executor());
4759 fs.insert_tree(path!("/test"), json!({})).await;
4760 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4761 let project_context = cx.new(|_cx| ProjectContext::default());
4762 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4763 let context_server_registry =
4764 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4765 cx.update(LanguageModelRegistry::test);
4766 let model = Arc::new(FakeLanguageModel::default());
4767 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4768 let native_agent = NativeAgent::new(
4769 project.clone(),
4770 thread_store,
4771 Templates::new(),
4772 None,
4773 fs,
4774 &mut cx.to_async(),
4775 )
4776 .await
4777 .unwrap();
4778 let parent_thread = cx.new(|cx| {
4779 let mut thread = Thread::new(
4780 project.clone(),
4781 project_context,
4782 context_server_registry,
4783 Templates::new(),
4784 Some(model.clone()),
4785 cx,
4786 );
4787 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4788 thread.add_tool(GrepTool::new(project.clone()), None);
4789 thread
4790 });
4791
4792 let _subagent_handle = cx
4793 .update(|cx| {
4794 NativeThreadEnvironment::create_subagent_thread(
4795 native_agent.downgrade(),
4796 parent_thread.clone(),
4797 "some title".to_string(),
4798 "task prompt".to_string(),
4799 Some(Duration::from_millis(10)),
4800 Some(vec!["grep".to_string()]),
4801 cx,
4802 )
4803 })
4804 .expect("Failed to create subagent");
4805
4806 cx.run_until_parked();
4807
4808 let tools = model
4809 .pending_completions()
4810 .last()
4811 .unwrap()
4812 .tools
4813 .iter()
4814 .map(|tool| tool.name.clone())
4815 .collect::<Vec<_>>();
4816 assert_eq!(tools, vec!["grep"]);
4817}
4818
4819#[gpui::test]
4820async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4821 init_test(cx);
4822
4823 let fs = FakeFs::new(cx.executor());
4824 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4825 .await;
4826 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4827
4828 cx.update(|cx| {
4829 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4830 settings.tool_permissions.tools.insert(
4831 EditFileTool::NAME.into(),
4832 agent_settings::ToolRules {
4833 default: Some(settings::ToolPermissionMode::Allow),
4834 always_allow: vec![],
4835 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4836 always_confirm: vec![],
4837 invalid_patterns: vec![],
4838 },
4839 );
4840 agent_settings::AgentSettings::override_global(settings, cx);
4841 });
4842
4843 let context_server_registry =
4844 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4845 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4846 let templates = crate::Templates::new();
4847 let thread = cx.new(|cx| {
4848 crate::Thread::new(
4849 project.clone(),
4850 cx.new(|_cx| prompt_store::ProjectContext::default()),
4851 context_server_registry,
4852 templates.clone(),
4853 None,
4854 cx,
4855 )
4856 });
4857
4858 #[allow(clippy::arc_with_non_send_sync)]
4859 let tool = Arc::new(crate::EditFileTool::new(
4860 project.clone(),
4861 thread.downgrade(),
4862 language_registry,
4863 templates,
4864 ));
4865 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4866
4867 let task = cx.update(|cx| {
4868 tool.run(
4869 crate::EditFileToolInput {
4870 display_description: "Edit sensitive file".to_string(),
4871 path: "root/sensitive_config.txt".into(),
4872 mode: crate::EditFileMode::Edit,
4873 },
4874 event_stream,
4875 cx,
4876 )
4877 });
4878
4879 let result = task.await;
4880 assert!(result.is_err(), "expected edit to be blocked");
4881 assert!(
4882 result.unwrap_err().to_string().contains("blocked"),
4883 "error should mention the edit was blocked"
4884 );
4885}
4886
4887#[gpui::test]
4888async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4889 init_test(cx);
4890
4891 let fs = FakeFs::new(cx.executor());
4892 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4893 .await;
4894 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4895
4896 cx.update(|cx| {
4897 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4898 settings.tool_permissions.tools.insert(
4899 DeletePathTool::NAME.into(),
4900 agent_settings::ToolRules {
4901 default: Some(settings::ToolPermissionMode::Allow),
4902 always_allow: vec![],
4903 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4904 always_confirm: vec![],
4905 invalid_patterns: vec![],
4906 },
4907 );
4908 agent_settings::AgentSettings::override_global(settings, cx);
4909 });
4910
4911 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4912
4913 #[allow(clippy::arc_with_non_send_sync)]
4914 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4915 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4916
4917 let task = cx.update(|cx| {
4918 tool.run(
4919 crate::DeletePathToolInput {
4920 path: "root/important_data.txt".to_string(),
4921 },
4922 event_stream,
4923 cx,
4924 )
4925 });
4926
4927 let result = task.await;
4928 assert!(result.is_err(), "expected deletion to be blocked");
4929 assert!(
4930 result.unwrap_err().to_string().contains("blocked"),
4931 "error should mention the deletion was blocked"
4932 );
4933}
4934
4935#[gpui::test]
4936async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4937 init_test(cx);
4938
4939 let fs = FakeFs::new(cx.executor());
4940 fs.insert_tree(
4941 "/root",
4942 json!({
4943 "safe.txt": "content",
4944 "protected": {}
4945 }),
4946 )
4947 .await;
4948 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4949
4950 cx.update(|cx| {
4951 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4952 settings.tool_permissions.tools.insert(
4953 MovePathTool::NAME.into(),
4954 agent_settings::ToolRules {
4955 default: Some(settings::ToolPermissionMode::Allow),
4956 always_allow: vec![],
4957 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4958 always_confirm: vec![],
4959 invalid_patterns: vec![],
4960 },
4961 );
4962 agent_settings::AgentSettings::override_global(settings, cx);
4963 });
4964
4965 #[allow(clippy::arc_with_non_send_sync)]
4966 let tool = Arc::new(crate::MovePathTool::new(project));
4967 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4968
4969 let task = cx.update(|cx| {
4970 tool.run(
4971 crate::MovePathToolInput {
4972 source_path: "root/safe.txt".to_string(),
4973 destination_path: "root/protected/safe.txt".to_string(),
4974 },
4975 event_stream,
4976 cx,
4977 )
4978 });
4979
4980 let result = task.await;
4981 assert!(
4982 result.is_err(),
4983 "expected move to be blocked due to destination path"
4984 );
4985 assert!(
4986 result.unwrap_err().to_string().contains("blocked"),
4987 "error should mention the move was blocked"
4988 );
4989}
4990
4991#[gpui::test]
4992async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
4993 init_test(cx);
4994
4995 let fs = FakeFs::new(cx.executor());
4996 fs.insert_tree(
4997 "/root",
4998 json!({
4999 "secret.txt": "secret content",
5000 "public": {}
5001 }),
5002 )
5003 .await;
5004 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5005
5006 cx.update(|cx| {
5007 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5008 settings.tool_permissions.tools.insert(
5009 MovePathTool::NAME.into(),
5010 agent_settings::ToolRules {
5011 default: Some(settings::ToolPermissionMode::Allow),
5012 always_allow: vec![],
5013 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5014 always_confirm: vec![],
5015 invalid_patterns: vec![],
5016 },
5017 );
5018 agent_settings::AgentSettings::override_global(settings, cx);
5019 });
5020
5021 #[allow(clippy::arc_with_non_send_sync)]
5022 let tool = Arc::new(crate::MovePathTool::new(project));
5023 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5024
5025 let task = cx.update(|cx| {
5026 tool.run(
5027 crate::MovePathToolInput {
5028 source_path: "root/secret.txt".to_string(),
5029 destination_path: "root/public/not_secret.txt".to_string(),
5030 },
5031 event_stream,
5032 cx,
5033 )
5034 });
5035
5036 let result = task.await;
5037 assert!(
5038 result.is_err(),
5039 "expected move to be blocked due to source path"
5040 );
5041 assert!(
5042 result.unwrap_err().to_string().contains("blocked"),
5043 "error should mention the move was blocked"
5044 );
5045}
5046
5047#[gpui::test]
5048async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5049 init_test(cx);
5050
5051 let fs = FakeFs::new(cx.executor());
5052 fs.insert_tree(
5053 "/root",
5054 json!({
5055 "confidential.txt": "confidential data",
5056 "dest": {}
5057 }),
5058 )
5059 .await;
5060 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5061
5062 cx.update(|cx| {
5063 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5064 settings.tool_permissions.tools.insert(
5065 CopyPathTool::NAME.into(),
5066 agent_settings::ToolRules {
5067 default: Some(settings::ToolPermissionMode::Allow),
5068 always_allow: vec![],
5069 always_deny: vec![
5070 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5071 ],
5072 always_confirm: vec![],
5073 invalid_patterns: vec![],
5074 },
5075 );
5076 agent_settings::AgentSettings::override_global(settings, cx);
5077 });
5078
5079 #[allow(clippy::arc_with_non_send_sync)]
5080 let tool = Arc::new(crate::CopyPathTool::new(project));
5081 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5082
5083 let task = cx.update(|cx| {
5084 tool.run(
5085 crate::CopyPathToolInput {
5086 source_path: "root/confidential.txt".to_string(),
5087 destination_path: "root/dest/copy.txt".to_string(),
5088 },
5089 event_stream,
5090 cx,
5091 )
5092 });
5093
5094 let result = task.await;
5095 assert!(result.is_err(), "expected copy to be blocked");
5096 assert!(
5097 result.unwrap_err().to_string().contains("blocked"),
5098 "error should mention the copy was blocked"
5099 );
5100}
5101
5102#[gpui::test]
5103async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5104 init_test(cx);
5105
5106 let fs = FakeFs::new(cx.executor());
5107 fs.insert_tree(
5108 "/root",
5109 json!({
5110 "normal.txt": "normal content",
5111 "readonly": {
5112 "config.txt": "readonly content"
5113 }
5114 }),
5115 )
5116 .await;
5117 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5118
5119 cx.update(|cx| {
5120 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5121 settings.tool_permissions.tools.insert(
5122 SaveFileTool::NAME.into(),
5123 agent_settings::ToolRules {
5124 default: Some(settings::ToolPermissionMode::Allow),
5125 always_allow: vec![],
5126 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5127 always_confirm: vec![],
5128 invalid_patterns: vec![],
5129 },
5130 );
5131 agent_settings::AgentSettings::override_global(settings, cx);
5132 });
5133
5134 #[allow(clippy::arc_with_non_send_sync)]
5135 let tool = Arc::new(crate::SaveFileTool::new(project));
5136 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5137
5138 let task = cx.update(|cx| {
5139 tool.run(
5140 crate::SaveFileToolInput {
5141 paths: vec![
5142 std::path::PathBuf::from("root/normal.txt"),
5143 std::path::PathBuf::from("root/readonly/config.txt"),
5144 ],
5145 },
5146 event_stream,
5147 cx,
5148 )
5149 });
5150
5151 let result = task.await;
5152 assert!(
5153 result.is_err(),
5154 "expected save to be blocked due to denied path"
5155 );
5156 assert!(
5157 result.unwrap_err().to_string().contains("blocked"),
5158 "error should mention the save was blocked"
5159 );
5160}
5161
5162#[gpui::test]
5163async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5164 init_test(cx);
5165
5166 let fs = FakeFs::new(cx.executor());
5167 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5168 .await;
5169 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5170
5171 cx.update(|cx| {
5172 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5173 settings.tool_permissions.tools.insert(
5174 SaveFileTool::NAME.into(),
5175 agent_settings::ToolRules {
5176 default: Some(settings::ToolPermissionMode::Allow),
5177 always_allow: vec![],
5178 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5179 always_confirm: vec![],
5180 invalid_patterns: vec![],
5181 },
5182 );
5183 agent_settings::AgentSettings::override_global(settings, cx);
5184 });
5185
5186 #[allow(clippy::arc_with_non_send_sync)]
5187 let tool = Arc::new(crate::SaveFileTool::new(project));
5188 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5189
5190 let task = cx.update(|cx| {
5191 tool.run(
5192 crate::SaveFileToolInput {
5193 paths: vec![std::path::PathBuf::from("root/config.secret")],
5194 },
5195 event_stream,
5196 cx,
5197 )
5198 });
5199
5200 let result = task.await;
5201 assert!(result.is_err(), "expected save to be blocked");
5202 assert!(
5203 result.unwrap_err().to_string().contains("blocked"),
5204 "error should mention the save was blocked"
5205 );
5206}
5207
5208#[gpui::test]
5209async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5210 init_test(cx);
5211
5212 cx.update(|cx| {
5213 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5214 settings.tool_permissions.tools.insert(
5215 WebSearchTool::NAME.into(),
5216 agent_settings::ToolRules {
5217 default: Some(settings::ToolPermissionMode::Allow),
5218 always_allow: vec![],
5219 always_deny: vec![
5220 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5221 ],
5222 always_confirm: vec![],
5223 invalid_patterns: vec![],
5224 },
5225 );
5226 agent_settings::AgentSettings::override_global(settings, cx);
5227 });
5228
5229 #[allow(clippy::arc_with_non_send_sync)]
5230 let tool = Arc::new(crate::WebSearchTool);
5231 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5232
5233 let input: crate::WebSearchToolInput =
5234 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5235
5236 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5237
5238 let result = task.await;
5239 assert!(result.is_err(), "expected search to be blocked");
5240 assert!(
5241 result.unwrap_err().to_string().contains("blocked"),
5242 "error should mention the search was blocked"
5243 );
5244}
5245
5246#[gpui::test]
5247async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5248 init_test(cx);
5249
5250 let fs = FakeFs::new(cx.executor());
5251 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5252 .await;
5253 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5254
5255 cx.update(|cx| {
5256 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5257 settings.tool_permissions.tools.insert(
5258 EditFileTool::NAME.into(),
5259 agent_settings::ToolRules {
5260 default: Some(settings::ToolPermissionMode::Confirm),
5261 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5262 always_deny: vec![],
5263 always_confirm: vec![],
5264 invalid_patterns: vec![],
5265 },
5266 );
5267 agent_settings::AgentSettings::override_global(settings, cx);
5268 });
5269
5270 let context_server_registry =
5271 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5272 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5273 let templates = crate::Templates::new();
5274 let thread = cx.new(|cx| {
5275 crate::Thread::new(
5276 project.clone(),
5277 cx.new(|_cx| prompt_store::ProjectContext::default()),
5278 context_server_registry,
5279 templates.clone(),
5280 None,
5281 cx,
5282 )
5283 });
5284
5285 #[allow(clippy::arc_with_non_send_sync)]
5286 let tool = Arc::new(crate::EditFileTool::new(
5287 project,
5288 thread.downgrade(),
5289 language_registry,
5290 templates,
5291 ));
5292 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5293
5294 let _task = cx.update(|cx| {
5295 tool.run(
5296 crate::EditFileToolInput {
5297 display_description: "Edit README".to_string(),
5298 path: "root/README.md".into(),
5299 mode: crate::EditFileMode::Edit,
5300 },
5301 event_stream,
5302 cx,
5303 )
5304 });
5305
5306 cx.run_until_parked();
5307
5308 let event = rx.try_next();
5309 assert!(
5310 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5311 "expected no authorization request for allowed .md file"
5312 );
5313}
5314
5315#[gpui::test]
5316async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5317 init_test(cx);
5318
5319 let fs = FakeFs::new(cx.executor());
5320 fs.insert_tree(
5321 "/root",
5322 json!({
5323 ".zed": {
5324 "settings.json": "{}"
5325 },
5326 "README.md": "# Hello"
5327 }),
5328 )
5329 .await;
5330 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5331
5332 cx.update(|cx| {
5333 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5334 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5335 agent_settings::AgentSettings::override_global(settings, cx);
5336 });
5337
5338 let context_server_registry =
5339 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5340 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5341 let templates = crate::Templates::new();
5342 let thread = cx.new(|cx| {
5343 crate::Thread::new(
5344 project.clone(),
5345 cx.new(|_cx| prompt_store::ProjectContext::default()),
5346 context_server_registry,
5347 templates.clone(),
5348 None,
5349 cx,
5350 )
5351 });
5352
5353 #[allow(clippy::arc_with_non_send_sync)]
5354 let tool = Arc::new(crate::EditFileTool::new(
5355 project,
5356 thread.downgrade(),
5357 language_registry,
5358 templates,
5359 ));
5360
5361 // Editing a file inside .zed/ should still prompt even with global default: allow,
5362 // because local settings paths are sensitive and require confirmation regardless.
5363 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5364 let _task = cx.update(|cx| {
5365 tool.run(
5366 crate::EditFileToolInput {
5367 display_description: "Edit local settings".to_string(),
5368 path: "root/.zed/settings.json".into(),
5369 mode: crate::EditFileMode::Edit,
5370 },
5371 event_stream,
5372 cx,
5373 )
5374 });
5375
5376 let _update = rx.expect_update_fields().await;
5377 let _auth = rx.expect_authorization().await;
5378}
5379
5380#[gpui::test]
5381async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5382 init_test(cx);
5383
5384 cx.update(|cx| {
5385 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5386 settings.tool_permissions.tools.insert(
5387 FetchTool::NAME.into(),
5388 agent_settings::ToolRules {
5389 default: Some(settings::ToolPermissionMode::Allow),
5390 always_allow: vec![],
5391 always_deny: vec![
5392 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5393 ],
5394 always_confirm: vec![],
5395 invalid_patterns: vec![],
5396 },
5397 );
5398 agent_settings::AgentSettings::override_global(settings, cx);
5399 });
5400
5401 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5402
5403 #[allow(clippy::arc_with_non_send_sync)]
5404 let tool = Arc::new(crate::FetchTool::new(http_client));
5405 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5406
5407 let input: crate::FetchToolInput =
5408 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5409
5410 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5411
5412 let result = task.await;
5413 assert!(result.is_err(), "expected fetch to be blocked");
5414 assert!(
5415 result.unwrap_err().to_string().contains("blocked"),
5416 "error should mention the fetch was blocked"
5417 );
5418}
5419
5420#[gpui::test]
5421async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5422 init_test(cx);
5423
5424 cx.update(|cx| {
5425 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5426 settings.tool_permissions.tools.insert(
5427 FetchTool::NAME.into(),
5428 agent_settings::ToolRules {
5429 default: Some(settings::ToolPermissionMode::Confirm),
5430 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5431 always_deny: vec![],
5432 always_confirm: vec![],
5433 invalid_patterns: vec![],
5434 },
5435 );
5436 agent_settings::AgentSettings::override_global(settings, cx);
5437 });
5438
5439 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5440
5441 #[allow(clippy::arc_with_non_send_sync)]
5442 let tool = Arc::new(crate::FetchTool::new(http_client));
5443 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5444
5445 let input: crate::FetchToolInput =
5446 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5447
5448 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5449
5450 cx.run_until_parked();
5451
5452 let event = rx.try_next();
5453 assert!(
5454 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5455 "expected no authorization request for allowed docs.rs URL"
5456 );
5457}
5458
5459#[gpui::test]
5460async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5461 init_test(cx);
5462 always_allow_tools(cx);
5463
5464 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5465 let fake_model = model.as_fake();
5466
5467 // Add a tool so we can simulate tool calls
5468 thread.update(cx, |thread, _cx| {
5469 thread.add_tool(EchoTool, None);
5470 });
5471
5472 // Start a turn by sending a message
5473 let mut events = thread
5474 .update(cx, |thread, cx| {
5475 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5476 })
5477 .unwrap();
5478 cx.run_until_parked();
5479
5480 // Simulate the model making a tool call
5481 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5482 LanguageModelToolUse {
5483 id: "tool_1".into(),
5484 name: "echo".into(),
5485 raw_input: r#"{"text": "hello"}"#.into(),
5486 input: json!({"text": "hello"}),
5487 is_input_complete: true,
5488 thought_signature: None,
5489 },
5490 ));
5491 fake_model
5492 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5493
5494 // Signal that a message is queued before ending the stream
5495 thread.update(cx, |thread, _cx| {
5496 thread.set_has_queued_message(true);
5497 });
5498
5499 // Now end the stream - tool will run, and the boundary check should see the queue
5500 fake_model.end_last_completion_stream();
5501
5502 // Collect all events until the turn stops
5503 let all_events = collect_events_until_stop(&mut events, cx).await;
5504
5505 // Verify we received the tool call event
5506 let tool_call_ids: Vec<_> = all_events
5507 .iter()
5508 .filter_map(|e| match e {
5509 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5510 _ => None,
5511 })
5512 .collect();
5513 assert_eq!(
5514 tool_call_ids,
5515 vec!["tool_1"],
5516 "Should have received a tool call event for our echo tool"
5517 );
5518
5519 // The turn should have stopped with EndTurn
5520 let stop_reasons = stop_events(all_events);
5521 assert_eq!(
5522 stop_reasons,
5523 vec![acp::StopReason::EndTurn],
5524 "Turn should have ended after tool completion due to queued message"
5525 );
5526
5527 // Verify the queued message flag is still set
5528 thread.update(cx, |thread, _cx| {
5529 assert!(
5530 thread.has_queued_message(),
5531 "Should still have queued message flag set"
5532 );
5533 });
5534
5535 // Thread should be idle now
5536 thread.update(cx, |thread, _cx| {
5537 assert!(
5538 thread.is_turn_complete(),
5539 "Thread should not be running after turn ends"
5540 );
5541 });
5542}