1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Test double for a terminal; it implements `crate::TerminalHandle` below.
struct FakeTerminalHandle {
    /// Set to `true` by `kill`.
    killed: Arc<AtomicBool>,
    /// Flag reported back by `was_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// One-shot sender that `signal_exit` takes and fires at most once.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task returned (cloned) from `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned output returned (cloned) from `current_output`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned (cloned) from `id`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
/// Test double for a subagent; it implements `SubagentHandle` below.
struct FakeSubagentHandle {
    /// Session id returned (cloned) from `id`.
    session_id: acp::SessionId,
    /// Shared task whose output `wait_for_summary` forwards as the summary.
    wait_for_summary_task: Shared<Task<String>>,
}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
/// Test `ThreadEnvironment` that hands out pre-configured fake handles.
/// Both slots default to `None`; see `with_terminal` / `with_subagent`.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Handle returned by `create_terminal`; that method panics when this is `None`.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Handle returned by `create_subagent`; that method panics when this is `None`.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    /// Every handle created so far by `create_terminal`, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
288#[gpui::test]
289async fn test_echo(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let events = thread
294 .update(cx, |thread, cx| {
295 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
296 })
297 .unwrap();
298 cx.run_until_parked();
299 fake_model.send_last_completion_stream_text_chunk("Hello");
300 fake_model
301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
302 fake_model.end_last_completion_stream();
303
304 let events = events.collect().await;
305 thread.update(cx, |thread, _cx| {
306 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
307 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
308 });
309 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
310}
311
/// Runs the terminal tool with a 5 ms timeout against a terminal that never
/// exits on its own, and checks that the handle gets killed and that the tool
/// result still contains the terminal's output.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep a reference to the fake handle so its `killed` flag can be inspected later.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first update from the tool should carry terminal content.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fuse and pin the task so it can be polled repeatedly without being consumed.
    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll until the task finishes, giving up after 500 ms measured with `Instant`.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            // "partial output" is the canned output of `new_never_exits`.
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        // Let pending work run, then wait on a 1 ms executor timer before polling again.
        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
377
378#[gpui::test]
379#[ignore]
380async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
381 init_test(cx);
382 always_allow_tools(cx);
383
384 let fs = FakeFs::new(cx.executor());
385 let project = Project::test(fs, [], cx).await;
386
387 let environment = Rc::new(cx.update(|cx| {
388 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
389 }));
390 let handle = environment.terminal_handle.clone().unwrap();
391
392 #[allow(clippy::arc_with_non_send_sync)]
393 let tool = Arc::new(crate::TerminalTool::new(project, environment));
394 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
395
396 let _task = cx.update(|cx| {
397 tool.run(
398 crate::TerminalToolInput {
399 command: "sleep 1000".to_string(),
400 cd: ".".to_string(),
401 timeout_ms: None,
402 },
403 event_stream,
404 cx,
405 )
406 });
407
408 let update = rx.expect_update_fields().await;
409 assert!(
410 update.content.iter().any(|blocks| {
411 blocks
412 .iter()
413 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
414 }),
415 "expected tool call update to include terminal content"
416 );
417
418 cx.background_executor
419 .timer(Duration::from_millis(25))
420 .await;
421
422 assert!(
423 !handle.was_killed(),
424 "did not expect terminal handle to be killed without a timeout"
425 );
426}
427
428#[gpui::test]
429async fn test_thinking(cx: &mut TestAppContext) {
430 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
431 let fake_model = model.as_fake();
432
433 let events = thread
434 .update(cx, |thread, cx| {
435 thread.send(
436 UserMessageId::new(),
437 [indoc! {"
438 Testing:
439
440 Generate a thinking step where you just think the word 'Think',
441 and have your final answer be 'Hello'
442 "}],
443 cx,
444 )
445 })
446 .unwrap();
447 cx.run_until_parked();
448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
449 text: "Think".to_string(),
450 signature: None,
451 });
452 fake_model.send_last_completion_stream_text_chunk("Hello");
453 fake_model
454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
455 fake_model.end_last_completion_stream();
456
457 let events = events.collect().await;
458 thread.update(cx, |thread, _cx| {
459 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
460 assert_eq!(
461 thread.last_message().unwrap().to_markdown(),
462 indoc! {"
463 <think>Think</think>
464 Hello
465 "}
466 )
467 });
468 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
469}
470
471#[gpui::test]
472async fn test_system_prompt(cx: &mut TestAppContext) {
473 let ThreadTest {
474 model,
475 thread,
476 project_context,
477 ..
478 } = setup(cx, TestModel::Fake).await;
479 let fake_model = model.as_fake();
480
481 project_context.update(cx, |project_context, _cx| {
482 project_context.shell = "test-shell".into()
483 });
484 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["abc"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491 let mut pending_completions = fake_model.pending_completions();
492 assert_eq!(
493 pending_completions.len(),
494 1,
495 "unexpected pending completions: {:?}",
496 pending_completions
497 );
498
499 let pending_completion = pending_completions.pop().unwrap();
500 assert_eq!(pending_completion.messages[0].role, Role::System);
501
502 let system_message = &pending_completion.messages[0];
503 let system_prompt = system_message.content[0].to_str().unwrap();
504 assert!(
505 system_prompt.contains("test-shell"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509 assert!(
510 system_prompt.contains("## Fixing Diagnostics"),
511 "unexpected system message: {:?}",
512 system_message
513 );
514}
515
516#[gpui::test]
517async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
519 let fake_model = model.as_fake();
520
521 thread
522 .update(cx, |thread, cx| {
523 thread.send(UserMessageId::new(), ["abc"], cx)
524 })
525 .unwrap();
526 cx.run_until_parked();
527 let mut pending_completions = fake_model.pending_completions();
528 assert_eq!(
529 pending_completions.len(),
530 1,
531 "unexpected pending completions: {:?}",
532 pending_completions
533 );
534
535 let pending_completion = pending_completions.pop().unwrap();
536 assert_eq!(pending_completion.messages[0].role, Role::System);
537
538 let system_message = &pending_completion.messages[0];
539 let system_prompt = system_message.content[0].to_str().unwrap();
540 assert!(
541 !system_prompt.contains("## Tool Use"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545 assert!(
546 !system_prompt.contains("## Fixing Diagnostics"),
547 "unexpected system message: {:?}",
548 system_message
549 );
550}
551
/// Checks which request messages carry `cache: true` as a conversation grows:
/// after each send (and after a tool result) only the final message in the
/// request is expected to be marked for caching.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // `messages[0]` is skipped in these comparisons; only the conversation turns are checked.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request should end with the echo tool's result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
697
698#[gpui::test]
699#[cfg_attr(not(feature = "e2e"), ignore)]
700async fn test_basic_tool_calls(cx: &mut TestAppContext) {
701 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
702
703 // Test a tool call that's likely to complete *before* streaming stops.
704 let events = thread
705 .update(cx, |thread, cx| {
706 thread.add_tool(EchoTool, None);
707 thread.send(
708 UserMessageId::new(),
709 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
710 cx,
711 )
712 })
713 .unwrap()
714 .collect()
715 .await;
716 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
717
718 // Test a tool calls that's likely to complete *after* streaming stops.
719 let events = thread
720 .update(cx, |thread, cx| {
721 thread.remove_tool(&EchoTool::NAME);
722 thread.add_tool(DelayTool, None);
723 thread.send(
724 UserMessageId::new(),
725 [
726 "Now call the delay tool with 200ms.",
727 "When the timer goes off, then you echo the output of the tool.",
728 ],
729 cx,
730 )
731 })
732 .unwrap()
733 .collect()
734 .await;
735 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
736 thread.update(cx, |thread, _cx| {
737 assert!(
738 thread
739 .last_message()
740 .unwrap()
741 .as_agent_message()
742 .unwrap()
743 .content
744 .iter()
745 .any(|content| {
746 if let AgentMessageContent::Text(text) = content {
747 text.contains("Ding")
748 } else {
749 false
750 }
751 }),
752 "{}",
753 thread.to_markdown()
754 );
755 });
756}
757
/// End-to-end test (ignored unless the `e2e` feature is enabled) that watches
/// tool-call events for the `word_list` tool and checks that at least one
/// partially streamed tool input is observed before the input completes.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Register the word_list tool and prompt the model to use it.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool, None);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Still pending: an incomplete input that lacks "g" counts as a
                        // partially streamed tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // No longer pending: both the "a" and "g" keys must be present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
809
/// Exercises the tool authorization flow with `ToolRequiringPermission`:
/// allow once, deny, always-allow, and finally a call that is expected to
/// produce a result without any further authorization response.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission, None);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Have the fake model request the tool twice in the same completion.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The next request's last message should hold one result per tool call:
    // a success for the allowed call and an error for the denied one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // No authorization response is sent this time; the result must appear anyway.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
948
949#[gpui::test]
950async fn test_tool_hallucination(cx: &mut TestAppContext) {
951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
952 let fake_model = model.as_fake();
953
954 let mut events = thread
955 .update(cx, |thread, cx| {
956 thread.send(UserMessageId::new(), ["abc"], cx)
957 })
958 .unwrap();
959 cx.run_until_parked();
960 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
961 LanguageModelToolUse {
962 id: "tool_id_1".into(),
963 name: "nonexistent_tool".into(),
964 raw_input: "{}".into(),
965 input: json!({}),
966 is_input_complete: true,
967 thought_signature: None,
968 },
969 ));
970 fake_model.end_last_completion_stream();
971
972 let tool_call = expect_tool_call(&mut events).await;
973 assert_eq!(tool_call.title, "nonexistent_tool");
974 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
975 let update = expect_tool_call_update_fields(&mut events).await;
976 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
977}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
/// A terminal command with a recognizable program name should yield three
/// dropdown choices: always for the tool, always for `cargo` commands, and once.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let command = "cargo build --release".to_string();
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec![command]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo` commands"));
    assert!(labels.contains(&"Only this time"));
}
1059
/// An edit-file request for `src/main.rs` should offer an always-allow choice
/// for the tool and one for the `src/` directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let path = "src/main.rs".to_string();
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec![path]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
}
1077
/// Fetching a URL yields a dropdown that offers both a tool-wide and a
/// domain-scoped "always" choice.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let built = ToolPermissionContext::new(
        FetchTool::NAME,
        vec![String::from("https://docs.rs/gpui")],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    // True when some choice's allow option carries exactly this label.
    let offers = |label: &str| {
        choices.iter().any(|choice| {
            let name: &str = choice.allow.name.as_ref();
            name == label
        })
    };
    assert!(offers("Always for fetch"));
    assert!(offers("Always for `docs.rs`"));
}
1095
/// A terminal command for which no pattern choice is produced (`./deploy.sh ...`)
/// yields only the tool-wide and one-shot choices.
#[test]
fn test_permission_options_without_pattern() {
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("./deploy.sh --production")],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);

    // True when some choice's allow option carries exactly this label.
    let offers = |label: &str| {
        choices.iter().any(|choice| {
            let name: &str = choice.allow.name.as_ref();
            name == label
        })
    };
    assert!(offers("Always for terminal"));
    assert!(offers("Only this time"));

    // No label may mention a per-command pattern.
    let none_mention_commands = choices.iter().all(|choice| {
        let name: &str = choice.allow.name.as_ref();
        !name.contains("commands")
    });
    assert!(none_mention_commands);
}
1117
/// Symlink-target authorization yields a flat list with exactly two one-shot
/// options: `allow` (AllowOnce) and `deny` (RejectOnce).
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    let built =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()])
            .build_permission_options();

    let PermissionOptions::Flat(options) = built else {
        panic!("Expected flat permission options for symlink target authorization");
    };

    assert_eq!(options.len(), 2);

    // True when an option with this id and this kind is present.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1138
/// Checks the option ids generated for a terminal command: tool-wide, one-shot, and
/// pattern-scoped ids must exist on both the allow and the deny side.
#[test]
fn test_permission_option_ids_for_terminal() {
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = built else {
        panic!("Expected dropdown permission options");
    };

    let mut allow_ids: Vec<String> = Vec::new();
    let mut deny_ids: Vec<String> = Vec::new();
    for choice in choices.iter() {
        allow_ids.push(choice.allow.option_id.0.to_string());
        deny_ids.push(choice.deny.option_id.0.to_string());
    }

    // Allow side.
    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    let has_allow_pattern = allow_ids
        .iter()
        .any(|id| id.starts_with("always_allow_pattern:terminal\n"));
    assert!(has_allow_pattern, "Missing allow pattern option");

    // Deny side.
    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    let has_deny_pattern = deny_ids
        .iter()
        .any(|id| id.starts_with("always_deny_pattern:terminal\n"));
    assert!(has_deny_pattern, "Missing deny pattern option");
}
1178
1179#[gpui::test]
1180#[cfg_attr(not(feature = "e2e"), ignore)]
1181async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1182 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1183
1184 // Test concurrent tool calls with different delay times
1185 let events = thread
1186 .update(cx, |thread, cx| {
1187 thread.add_tool(DelayTool, None);
1188 thread.send(
1189 UserMessageId::new(),
1190 [
1191 "Call the delay tool twice in the same message.",
1192 "Once with 100ms. Once with 300ms.",
1193 "When both timers are complete, describe the outputs.",
1194 ],
1195 cx,
1196 )
1197 })
1198 .unwrap()
1199 .collect()
1200 .await;
1201
1202 let stop_reasons = stop_events(events);
1203 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1204
1205 thread.update(cx, |thread, _cx| {
1206 let last_message = thread.last_message().unwrap();
1207 let agent_message = last_message.as_agent_message().unwrap();
1208 let text = agent_message
1209 .content
1210 .iter()
1211 .filter_map(|content| {
1212 if let AgentMessageContent::Text(text) = content {
1213 Some(text.as_str())
1214 } else {
1215 None
1216 }
1217 })
1218 .collect::<String>();
1219
1220 assert!(text.contains("Ding"));
1221 });
1222}
1223
1224#[gpui::test]
1225async fn test_profiles(cx: &mut TestAppContext) {
1226 let ThreadTest {
1227 model, thread, fs, ..
1228 } = setup(cx, TestModel::Fake).await;
1229 let fake_model = model.as_fake();
1230
1231 thread.update(cx, |thread, _cx| {
1232 thread.add_tool(DelayTool, None);
1233 thread.add_tool(EchoTool, None);
1234 thread.add_tool(InfiniteTool, None);
1235 });
1236
1237 // Override profiles and wait for settings to be loaded.
1238 fs.insert_file(
1239 paths::settings_file(),
1240 json!({
1241 "agent": {
1242 "profiles": {
1243 "test-1": {
1244 "name": "Test Profile 1",
1245 "tools": {
1246 EchoTool::NAME: true,
1247 DelayTool::NAME: true,
1248 }
1249 },
1250 "test-2": {
1251 "name": "Test Profile 2",
1252 "tools": {
1253 InfiniteTool::NAME: true,
1254 }
1255 }
1256 }
1257 }
1258 })
1259 .to_string()
1260 .into_bytes(),
1261 )
1262 .await;
1263 cx.run_until_parked();
1264
1265 // Test that test-1 profile (default) has echo and delay tools
1266 thread
1267 .update(cx, |thread, cx| {
1268 thread.set_profile(AgentProfileId("test-1".into()), cx);
1269 thread.send(UserMessageId::new(), ["test"], cx)
1270 })
1271 .unwrap();
1272 cx.run_until_parked();
1273
1274 let mut pending_completions = fake_model.pending_completions();
1275 assert_eq!(pending_completions.len(), 1);
1276 let completion = pending_completions.pop().unwrap();
1277 let tool_names: Vec<String> = completion
1278 .tools
1279 .iter()
1280 .map(|tool| tool.name.clone())
1281 .collect();
1282 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1283 fake_model.end_last_completion_stream();
1284
1285 // Switch to test-2 profile, and verify that it has only the infinite tool.
1286 thread
1287 .update(cx, |thread, cx| {
1288 thread.set_profile(AgentProfileId("test-2".into()), cx);
1289 thread.send(UserMessageId::new(), ["test2"], cx)
1290 })
1291 .unwrap();
1292 cx.run_until_parked();
1293 let mut pending_completions = fake_model.pending_completions();
1294 assert_eq!(pending_completions.len(), 1);
1295 let completion = pending_completions.pop().unwrap();
1296 let tool_names: Vec<String> = completion
1297 .tools
1298 .iter()
1299 .map(|tool| tool.name.clone())
1300 .collect();
1301 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1302}
1303
1304#[gpui::test]
1305async fn test_mcp_tools(cx: &mut TestAppContext) {
1306 let ThreadTest {
1307 model,
1308 thread,
1309 context_server_store,
1310 fs,
1311 ..
1312 } = setup(cx, TestModel::Fake).await;
1313 let fake_model = model.as_fake();
1314
1315 // Override profiles and wait for settings to be loaded.
1316 fs.insert_file(
1317 paths::settings_file(),
1318 json!({
1319 "agent": {
1320 "tool_permissions": { "default": "allow" },
1321 "profiles": {
1322 "test": {
1323 "name": "Test Profile",
1324 "enable_all_context_servers": true,
1325 "tools": {
1326 EchoTool::NAME: true,
1327 }
1328 },
1329 }
1330 }
1331 })
1332 .to_string()
1333 .into_bytes(),
1334 )
1335 .await;
1336 cx.run_until_parked();
1337 thread.update(cx, |thread, cx| {
1338 thread.set_profile(AgentProfileId("test".into()), cx)
1339 });
1340
1341 let mut mcp_tool_calls = setup_context_server(
1342 "test_server",
1343 vec![context_server::types::Tool {
1344 name: "echo".into(),
1345 description: None,
1346 input_schema: serde_json::to_value(EchoTool::input_schema(
1347 LanguageModelToolSchemaFormat::JsonSchema,
1348 ))
1349 .unwrap(),
1350 output_schema: None,
1351 annotations: None,
1352 }],
1353 &context_server_store,
1354 cx,
1355 );
1356
1357 let events = thread.update(cx, |thread, cx| {
1358 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1359 });
1360 cx.run_until_parked();
1361
1362 // Simulate the model calling the MCP tool.
1363 let completion = fake_model.pending_completions().pop().unwrap();
1364 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1365 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1366 LanguageModelToolUse {
1367 id: "tool_1".into(),
1368 name: "echo".into(),
1369 raw_input: json!({"text": "test"}).to_string(),
1370 input: json!({"text": "test"}),
1371 is_input_complete: true,
1372 thought_signature: None,
1373 },
1374 ));
1375 fake_model.end_last_completion_stream();
1376 cx.run_until_parked();
1377
1378 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1379 assert_eq!(tool_call_params.name, "echo");
1380 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1381 tool_call_response
1382 .send(context_server::types::CallToolResponse {
1383 content: vec![context_server::types::ToolResponseContent::Text {
1384 text: "test".into(),
1385 }],
1386 is_error: None,
1387 meta: None,
1388 structured_content: None,
1389 })
1390 .unwrap();
1391 cx.run_until_parked();
1392
1393 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1394 fake_model.send_last_completion_stream_text_chunk("Done!");
1395 fake_model.end_last_completion_stream();
1396 events.collect::<Vec<_>>().await;
1397
1398 // Send again after adding the echo tool, ensuring the name collision is resolved.
1399 let events = thread.update(cx, |thread, cx| {
1400 thread.add_tool(EchoTool, None);
1401 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1402 });
1403 cx.run_until_parked();
1404 let completion = fake_model.pending_completions().pop().unwrap();
1405 assert_eq!(
1406 tool_names_for_completion(&completion),
1407 vec!["echo", "test_server_echo"]
1408 );
1409 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1410 LanguageModelToolUse {
1411 id: "tool_2".into(),
1412 name: "test_server_echo".into(),
1413 raw_input: json!({"text": "mcp"}).to_string(),
1414 input: json!({"text": "mcp"}),
1415 is_input_complete: true,
1416 thought_signature: None,
1417 },
1418 ));
1419 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1420 LanguageModelToolUse {
1421 id: "tool_3".into(),
1422 name: "echo".into(),
1423 raw_input: json!({"text": "native"}).to_string(),
1424 input: json!({"text": "native"}),
1425 is_input_complete: true,
1426 thought_signature: None,
1427 },
1428 ));
1429 fake_model.end_last_completion_stream();
1430 cx.run_until_parked();
1431
1432 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1433 assert_eq!(tool_call_params.name, "echo");
1434 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1435 tool_call_response
1436 .send(context_server::types::CallToolResponse {
1437 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1438 is_error: None,
1439 meta: None,
1440 structured_content: None,
1441 })
1442 .unwrap();
1443 cx.run_until_parked();
1444
1445 // Ensure the tool results were inserted with the correct names.
1446 let completion = fake_model.pending_completions().pop().unwrap();
1447 assert_eq!(
1448 completion.messages.last().unwrap().content,
1449 vec![
1450 MessageContent::ToolResult(LanguageModelToolResult {
1451 tool_use_id: "tool_3".into(),
1452 tool_name: "echo".into(),
1453 is_error: false,
1454 content: "native".into(),
1455 output: Some("native".into()),
1456 },),
1457 MessageContent::ToolResult(LanguageModelToolResult {
1458 tool_use_id: "tool_2".into(),
1459 tool_name: "test_server_echo".into(),
1460 is_error: false,
1461 content: "mcp".into(),
1462 output: Some("mcp".into()),
1463 },),
1464 ]
1465 );
1466 fake_model.end_last_completion_stream();
1467 events.collect::<Vec<_>>().await;
1468}
1469
/// Verifies that a stored MCP tool result is still emitted when the thread is
/// replayed after the context server that produced it has been stopped.
///
/// Flow: register a context server with an `issue_read` tool, let the fake model call
/// it, answer from the server, finish the turn, stop the server, then replay the
/// thread and check the replayed `ToolCall` / `ToolCallUpdate` events for `tool_1`.
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    // NOTE(review): `test_mcp_tools` grants permission through
    // `"tool_permissions": { "default": "allow" }` instead; confirm that
    // `always_allow_tool_actions` is still honored by the settings schema.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // Drain the event stream so the turn has fully finished before inspecting state.
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown output.
    // The tool result is in the first assistant message (not the last one, which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    context_server_store.update(cx, |store, cx| {
        // The result of stopping the server is deliberately ignored.
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    // Scan the whole replay stream, keeping the events that belong to `tool_1`.
    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                // Only updates that carry a raw output are of interest.
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1651
/// Checks the final tool names offered to the model when native tools and several
/// context servers expose colliding and over-long tool names.
///
/// Four context servers (`xxx`, `yyy`, `zzz`, `Azure DevOps`) are registered next to
/// five native tools; the test then asserts the exact list of tool names in the
/// resulting completion request.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool, None);
        thread.add_tool(DelayTool, None);
        thread.add_tool(WordListTool, None);
        thread.add_tool(ToolRequiringPermission, None);
        thread.add_tool(InfiniteTool, None);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Two characters shorter than the limit; also exposed by server `zzz`.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One character shorter than the limit; also exposed by server `zzz`.
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One character longer than the limit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Expected names: colliding `echo` tools appear with a server-derived prefix, the
    // duplicated `a…` names appear with `y_` / `z_` prefixes, and the `b…` name appears
    // only once even though two servers expose it.
    // NOTE(review): the `c…` entry is presumably the over-long name truncated to
    // MAX_TOOL_NAME_LENGTH — confirm against the naming logic.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1835
/// End-to-end check (ignored unless the `e2e` feature is enabled): cancels a turn
/// while a tool is still running, expects the event stream to end with
/// `StopReason::Cancelled`, and then confirms a new message can still be sent.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool, None);
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls must arrive in this order; each one is popped from the front as it is seen.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                // Remember the echo call's id so its completion can be recognized below.
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // A `Completed` status update for the echo tool call.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tool calls were seen and the echo call has completed.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1921
/// Verifies that cancelling a thread while the terminal tool is running kills the
/// terminal, ends the turn with `StopReason::Cancelled`, and records a tool result
/// that contains the terminal's output plus a note that the user stopped the command.
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake environment whose terminal never exits on its own.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep a handle to the fake terminal so its state can be inspected after cancellation.
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    // The cancel task is detached rather than awaited here.
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // First tool use recorded in the last agent message.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        // Tool results are looked up by the tool use id.
        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2018
/// Verifies that a tool which observes cancellation lets `Thread::cancel` finish
/// promptly, sets its cancellation flag, and leaves the thread usable afterwards.
#[gpui::test]
async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
    // This test verifies that tools which properly handle cancellation via
    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
    // to cancellation and report that they were cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // `was_cancelled` is a shared flag that is checked after cancelling the thread.
    let (tool, was_cancelled) = CancellationAwareTool::new();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(tool, None);
            thread.send(
                UserMessageId::new(),
                ["call the cancellation aware tool"],
                cx,
            )
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the cancellation-aware tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "cancellation_aware_1".into(),
            name: "cancellation_aware".into(),
            raw_input: r#"{}"#.into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();

    // Wait for the tool call to be reported
    let mut tool_started = false;
    // Despite its name, `deadline` is an iteration budget for the polling loop below,
    // not a point in time.
    let deadline = cx.executor().num_cpus() * 100;
    for _ in 0..deadline {
        cx.run_until_parked();

        // Drain the events that are ready right now without blocking on the stream.
        while let Some(Some(event)) = events.next().now_or_never() {
            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
                if tool_call.title == "Cancellation Aware Tool" {
                    tool_started = true;
                    // Leaves only the inner drain loop; the outer loop exits just below.
                    break;
                }
            }
        }

        if tool_started {
            break;
        }

        // Wait briefly before polling again.
        cx.background_executor
            .timer(Duration::from_millis(10))
            .await;
    }
    assert!(tool_started, "expected cancellation aware tool to start");

    // Cancel the thread and wait for it to complete
    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));

    // The cancel task should complete promptly because the tool handles cancellation
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    // Race the cancel task against the timer; the test fails if the timer wins.
    futures::select! {
        _ = cancel_task.fuse() => {}
        _ = timeout.fuse() => {
            panic!("cancel task timed out - tool did not respond to cancellation");
        }
    }

    // Verify the tool detected cancellation via its flag
    assert!(
        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
        "tool should have detected cancellation via event_stream.cancelled_by_user()"
    );

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2113
2114/// Helper to verify thread can recover after cancellation by sending a simple message.
2115async fn verify_thread_recovery(
2116 thread: &Entity<Thread>,
2117 fake_model: &FakeLanguageModel,
2118 cx: &mut TestAppContext,
2119) {
2120 let events = thread
2121 .update(cx, |thread, cx| {
2122 thread.send(
2123 UserMessageId::new(),
2124 ["Testing: reply with 'Hello' then stop."],
2125 cx,
2126 )
2127 })
2128 .unwrap();
2129 cx.run_until_parked();
2130 fake_model.send_last_completion_stream_text_chunk("Hello");
2131 fake_model
2132 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2133 fake_model.end_last_completion_stream();
2134
2135 let events = events.collect::<Vec<_>>().await;
2136 thread.update(cx, |thread, _cx| {
2137 let message = thread.last_message().unwrap();
2138 let agent_message = message.as_agent_message().unwrap();
2139 assert_eq!(
2140 agent_message.content,
2141 vec![AgentMessageContent::Text("Hello".to_string())]
2142 );
2143 });
2144 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2145}
2146
2147/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2148async fn wait_for_terminal_tool_started(
2149 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2150 cx: &mut TestAppContext,
2151) {
2152 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2153 for _ in 0..deadline {
2154 cx.run_until_parked();
2155
2156 while let Some(Some(event)) = events.next().now_or_never() {
2157 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2158 update,
2159 ))) = &event
2160 {
2161 if update.fields.content.as_ref().is_some_and(|content| {
2162 content
2163 .iter()
2164 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2165 }) {
2166 return;
2167 }
2168 }
2169 }
2170
2171 cx.background_executor
2172 .timer(Duration::from_millis(10))
2173 .await;
2174 }
2175 panic!("terminal tool did not start within the expected time");
2176}
2177
2178/// Collects events until a Stop event is received, driving the executor to completion.
2179async fn collect_events_until_stop(
2180 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2181 cx: &mut TestAppContext,
2182) -> Vec<Result<ThreadEvent>> {
2183 let mut collected = Vec::new();
2184 let deadline = cx.executor().num_cpus() * 200;
2185
2186 for _ in 0..deadline {
2187 cx.executor().advance_clock(Duration::from_millis(10));
2188 cx.run_until_parked();
2189
2190 while let Some(Some(event)) = events.next().now_or_never() {
2191 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2192 collected.push(event);
2193 if is_stop {
2194 return collected;
2195 }
2196 }
2197 }
2198 panic!(
2199 "did not receive Stop event within the expected time; collected {} events",
2200 collected.len()
2201 );
2202}
2203
2204#[gpui::test]
2205async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2206 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2207 always_allow_tools(cx);
2208 let fake_model = model.as_fake();
2209
2210 let environment = Rc::new(cx.update(|cx| {
2211 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2212 }));
2213 let handle = environment.terminal_handle.clone().unwrap();
2214
2215 let message_id = UserMessageId::new();
2216 let mut events = thread
2217 .update(cx, |thread, cx| {
2218 thread.add_tool(
2219 crate::TerminalTool::new(thread.project().clone(), environment),
2220 None,
2221 );
2222 thread.send(message_id.clone(), ["run a command"], cx)
2223 })
2224 .unwrap();
2225
2226 cx.run_until_parked();
2227
2228 // Simulate the model calling the terminal tool
2229 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2230 LanguageModelToolUse {
2231 id: "terminal_tool_1".into(),
2232 name: TerminalTool::NAME.into(),
2233 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2234 input: json!({"command": "sleep 1000", "cd": "."}),
2235 is_input_complete: true,
2236 thought_signature: None,
2237 },
2238 ));
2239 fake_model.end_last_completion_stream();
2240
2241 // Wait for the terminal tool to start running
2242 wait_for_terminal_tool_started(&mut events, cx).await;
2243
2244 // Truncate the thread while the terminal is running
2245 thread
2246 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2247 .unwrap();
2248
2249 // Drive the executor to let cancellation complete
2250 let _ = collect_events_until_stop(&mut events, cx).await;
2251
2252 // Verify the terminal was killed
2253 assert!(
2254 handle.was_killed(),
2255 "expected terminal handle to be killed on truncate"
2256 );
2257
2258 // Verify the thread is empty after truncation
2259 thread.update(cx, |thread, _cx| {
2260 assert_eq!(
2261 thread.to_markdown(),
2262 "",
2263 "expected thread to be empty after truncating the only message"
2264 );
2265 });
2266
2267 // Verify we can send a new message after truncation
2268 verify_thread_recovery(&thread, &fake_model, cx).await;
2269}
2270
2271#[gpui::test]
2272async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2273 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2274 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2275 always_allow_tools(cx);
2276 let fake_model = model.as_fake();
2277
2278 let environment = Rc::new(MultiTerminalEnvironment::new());
2279
2280 let mut events = thread
2281 .update(cx, |thread, cx| {
2282 thread.add_tool(
2283 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2284 None,
2285 );
2286 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2287 })
2288 .unwrap();
2289
2290 cx.run_until_parked();
2291
2292 // Simulate the model calling two terminal tools
2293 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2294 LanguageModelToolUse {
2295 id: "terminal_tool_1".into(),
2296 name: TerminalTool::NAME.into(),
2297 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2298 input: json!({"command": "sleep 1000", "cd": "."}),
2299 is_input_complete: true,
2300 thought_signature: None,
2301 },
2302 ));
2303 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2304 LanguageModelToolUse {
2305 id: "terminal_tool_2".into(),
2306 name: TerminalTool::NAME.into(),
2307 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2308 input: json!({"command": "sleep 2000", "cd": "."}),
2309 is_input_complete: true,
2310 thought_signature: None,
2311 },
2312 ));
2313 fake_model.end_last_completion_stream();
2314
2315 // Wait for both terminal tools to start by counting terminal content updates
2316 let mut terminals_started = 0;
2317 let deadline = cx.executor().num_cpus() * 100;
2318 for _ in 0..deadline {
2319 cx.run_until_parked();
2320
2321 while let Some(Some(event)) = events.next().now_or_never() {
2322 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2323 update,
2324 ))) = &event
2325 {
2326 if update.fields.content.as_ref().is_some_and(|content| {
2327 content
2328 .iter()
2329 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2330 }) {
2331 terminals_started += 1;
2332 if terminals_started >= 2 {
2333 break;
2334 }
2335 }
2336 }
2337 }
2338 if terminals_started >= 2 {
2339 break;
2340 }
2341
2342 cx.background_executor
2343 .timer(Duration::from_millis(10))
2344 .await;
2345 }
2346 assert!(
2347 terminals_started >= 2,
2348 "expected 2 terminal tools to start, got {terminals_started}"
2349 );
2350
2351 // Cancel the thread while both terminals are running
2352 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2353
2354 // Collect remaining events
2355 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2356
2357 // Verify both terminal handles were killed
2358 let handles = environment.handles();
2359 assert_eq!(
2360 handles.len(),
2361 2,
2362 "expected 2 terminal handles to be created"
2363 );
2364 assert!(
2365 handles[0].was_killed(),
2366 "expected first terminal handle to be killed on cancellation"
2367 );
2368 assert!(
2369 handles[1].was_killed(),
2370 "expected second terminal handle to be killed on cancellation"
2371 );
2372
2373 // Verify we got a cancellation stop event
2374 assert_eq!(
2375 stop_events(remaining_events),
2376 vec![acp::StopReason::Cancelled],
2377 );
2378}
2379
2380#[gpui::test]
2381async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2382 // Tests that clicking the stop button on the terminal card (as opposed to the main
2383 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2384 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2385 always_allow_tools(cx);
2386 let fake_model = model.as_fake();
2387
2388 let environment = Rc::new(cx.update(|cx| {
2389 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2390 }));
2391 let handle = environment.terminal_handle.clone().unwrap();
2392
2393 let mut events = thread
2394 .update(cx, |thread, cx| {
2395 thread.add_tool(
2396 crate::TerminalTool::new(thread.project().clone(), environment),
2397 None,
2398 );
2399 thread.send(UserMessageId::new(), ["run a command"], cx)
2400 })
2401 .unwrap();
2402
2403 cx.run_until_parked();
2404
2405 // Simulate the model calling the terminal tool
2406 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2407 LanguageModelToolUse {
2408 id: "terminal_tool_1".into(),
2409 name: TerminalTool::NAME.into(),
2410 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2411 input: json!({"command": "sleep 1000", "cd": "."}),
2412 is_input_complete: true,
2413 thought_signature: None,
2414 },
2415 ));
2416 fake_model.end_last_completion_stream();
2417
2418 // Wait for the terminal tool to start running
2419 wait_for_terminal_tool_started(&mut events, cx).await;
2420
2421 // Simulate user clicking stop on the terminal card itself.
2422 // This sets the flag and signals exit (simulating what the real UI would do).
2423 handle.set_stopped_by_user(true);
2424 handle.killed.store(true, Ordering::SeqCst);
2425 handle.signal_exit();
2426
2427 // Wait for the tool to complete
2428 cx.run_until_parked();
2429
2430 // The thread continues after tool completion - simulate the model ending its turn
2431 fake_model
2432 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2433 fake_model.end_last_completion_stream();
2434
2435 // Collect remaining events
2436 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2437
2438 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2439 assert_eq!(
2440 stop_events(remaining_events),
2441 vec![acp::StopReason::EndTurn],
2442 );
2443
2444 // Verify the tool result indicates user stopped
2445 thread.update(cx, |thread, _cx| {
2446 let message = thread.last_message().unwrap();
2447 let agent_message = message.as_agent_message().unwrap();
2448
2449 let tool_use = agent_message
2450 .content
2451 .iter()
2452 .find_map(|content| match content {
2453 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2454 _ => None,
2455 })
2456 .expect("expected tool use in agent message");
2457
2458 let tool_result = agent_message
2459 .tool_results
2460 .get(&tool_use.id)
2461 .expect("expected tool result");
2462
2463 let result_text = match &tool_result.content {
2464 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2465 _ => panic!("expected text content in tool result"),
2466 };
2467
2468 assert!(
2469 result_text.contains("The user stopped this command"),
2470 "expected tool result to indicate user stopped, got: {result_text}"
2471 );
2472 });
2473}
2474
2475#[gpui::test]
2476async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2477 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2478 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2479 always_allow_tools(cx);
2480 let fake_model = model.as_fake();
2481
2482 let environment = Rc::new(cx.update(|cx| {
2483 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2484 }));
2485 let handle = environment.terminal_handle.clone().unwrap();
2486
2487 let mut events = thread
2488 .update(cx, |thread, cx| {
2489 thread.add_tool(
2490 crate::TerminalTool::new(thread.project().clone(), environment),
2491 None,
2492 );
2493 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2494 })
2495 .unwrap();
2496
2497 cx.run_until_parked();
2498
2499 // Simulate the model calling the terminal tool with a short timeout
2500 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2501 LanguageModelToolUse {
2502 id: "terminal_tool_1".into(),
2503 name: TerminalTool::NAME.into(),
2504 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2505 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2506 is_input_complete: true,
2507 thought_signature: None,
2508 },
2509 ));
2510 fake_model.end_last_completion_stream();
2511
2512 // Wait for the terminal tool to start running
2513 wait_for_terminal_tool_started(&mut events, cx).await;
2514
2515 // Advance clock past the timeout
2516 cx.executor().advance_clock(Duration::from_millis(200));
2517 cx.run_until_parked();
2518
2519 // The thread continues after tool completion - simulate the model ending its turn
2520 fake_model
2521 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2522 fake_model.end_last_completion_stream();
2523
2524 // Collect remaining events
2525 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2526
2527 // Verify the terminal was killed due to timeout
2528 assert!(
2529 handle.was_killed(),
2530 "expected terminal handle to be killed on timeout"
2531 );
2532
2533 // Verify we got an EndTurn (the tool completed, just with timeout)
2534 assert_eq!(
2535 stop_events(remaining_events),
2536 vec![acp::StopReason::EndTurn],
2537 );
2538
2539 // Verify the tool result indicates timeout, not user stopped
2540 thread.update(cx, |thread, _cx| {
2541 let message = thread.last_message().unwrap();
2542 let agent_message = message.as_agent_message().unwrap();
2543
2544 let tool_use = agent_message
2545 .content
2546 .iter()
2547 .find_map(|content| match content {
2548 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2549 _ => None,
2550 })
2551 .expect("expected tool use in agent message");
2552
2553 let tool_result = agent_message
2554 .tool_results
2555 .get(&tool_use.id)
2556 .expect("expected tool result");
2557
2558 let result_text = match &tool_result.content {
2559 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2560 _ => panic!("expected text content in tool result"),
2561 };
2562
2563 assert!(
2564 result_text.contains("timed out"),
2565 "expected tool result to indicate timeout, got: {result_text}"
2566 );
2567 assert!(
2568 !result_text.contains("The user stopped"),
2569 "tool result should not mention user stopped when it timed out, got: {result_text}"
2570 );
2571 });
2572}
2573
2574#[gpui::test]
2575async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2576 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2577 let fake_model = model.as_fake();
2578
2579 let events_1 = thread
2580 .update(cx, |thread, cx| {
2581 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2582 })
2583 .unwrap();
2584 cx.run_until_parked();
2585 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2586 cx.run_until_parked();
2587
2588 let events_2 = thread
2589 .update(cx, |thread, cx| {
2590 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2591 })
2592 .unwrap();
2593 cx.run_until_parked();
2594 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2595 fake_model
2596 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2597 fake_model.end_last_completion_stream();
2598
2599 let events_1 = events_1.collect::<Vec<_>>().await;
2600 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2601 let events_2 = events_2.collect::<Vec<_>>().await;
2602 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2603}
2604
2605#[gpui::test]
2606async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2607 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2608 let fake_model = model.as_fake();
2609
2610 let events_1 = thread
2611 .update(cx, |thread, cx| {
2612 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2613 })
2614 .unwrap();
2615 cx.run_until_parked();
2616 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2617 fake_model
2618 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2619 fake_model.end_last_completion_stream();
2620 let events_1 = events_1.collect::<Vec<_>>().await;
2621
2622 let events_2 = thread
2623 .update(cx, |thread, cx| {
2624 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2625 })
2626 .unwrap();
2627 cx.run_until_parked();
2628 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2629 fake_model
2630 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2631 fake_model.end_last_completion_stream();
2632 let events_2 = events_2.collect::<Vec<_>>().await;
2633
2634 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2635 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2636}
2637
2638#[gpui::test]
2639async fn test_refusal(cx: &mut TestAppContext) {
2640 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2641 let fake_model = model.as_fake();
2642
2643 let events = thread
2644 .update(cx, |thread, cx| {
2645 thread.send(UserMessageId::new(), ["Hello"], cx)
2646 })
2647 .unwrap();
2648 cx.run_until_parked();
2649 thread.read_with(cx, |thread, _| {
2650 assert_eq!(
2651 thread.to_markdown(),
2652 indoc! {"
2653 ## User
2654
2655 Hello
2656 "}
2657 );
2658 });
2659
2660 fake_model.send_last_completion_stream_text_chunk("Hey!");
2661 cx.run_until_parked();
2662 thread.read_with(cx, |thread, _| {
2663 assert_eq!(
2664 thread.to_markdown(),
2665 indoc! {"
2666 ## User
2667
2668 Hello
2669
2670 ## Assistant
2671
2672 Hey!
2673 "}
2674 );
2675 });
2676
2677 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2678 fake_model
2679 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2680 let events = events.collect::<Vec<_>>().await;
2681 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2682 thread.read_with(cx, |thread, _| {
2683 assert_eq!(thread.to_markdown(), "");
2684 });
2685}
2686
2687#[gpui::test]
2688async fn test_truncate_first_message(cx: &mut TestAppContext) {
2689 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2690 let fake_model = model.as_fake();
2691
2692 let message_id = UserMessageId::new();
2693 thread
2694 .update(cx, |thread, cx| {
2695 thread.send(message_id.clone(), ["Hello"], cx)
2696 })
2697 .unwrap();
2698 cx.run_until_parked();
2699 thread.read_with(cx, |thread, _| {
2700 assert_eq!(
2701 thread.to_markdown(),
2702 indoc! {"
2703 ## User
2704
2705 Hello
2706 "}
2707 );
2708 assert_eq!(thread.latest_token_usage(), None);
2709 });
2710
2711 fake_model.send_last_completion_stream_text_chunk("Hey!");
2712 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2713 language_model::TokenUsage {
2714 input_tokens: 32_000,
2715 output_tokens: 16_000,
2716 cache_creation_input_tokens: 0,
2717 cache_read_input_tokens: 0,
2718 },
2719 ));
2720 cx.run_until_parked();
2721 thread.read_with(cx, |thread, _| {
2722 assert_eq!(
2723 thread.to_markdown(),
2724 indoc! {"
2725 ## User
2726
2727 Hello
2728
2729 ## Assistant
2730
2731 Hey!
2732 "}
2733 );
2734 assert_eq!(
2735 thread.latest_token_usage(),
2736 Some(acp_thread::TokenUsage {
2737 used_tokens: 32_000 + 16_000,
2738 max_tokens: 1_000_000,
2739 max_output_tokens: None,
2740 input_tokens: 32_000,
2741 output_tokens: 16_000,
2742 })
2743 );
2744 });
2745
2746 thread
2747 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2748 .unwrap();
2749 cx.run_until_parked();
2750 thread.read_with(cx, |thread, _| {
2751 assert_eq!(thread.to_markdown(), "");
2752 assert_eq!(thread.latest_token_usage(), None);
2753 });
2754
2755 // Ensure we can still send a new message after truncation.
2756 thread
2757 .update(cx, |thread, cx| {
2758 thread.send(UserMessageId::new(), ["Hi"], cx)
2759 })
2760 .unwrap();
2761 thread.update(cx, |thread, _cx| {
2762 assert_eq!(
2763 thread.to_markdown(),
2764 indoc! {"
2765 ## User
2766
2767 Hi
2768 "}
2769 );
2770 });
2771 cx.run_until_parked();
2772 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2773 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2774 language_model::TokenUsage {
2775 input_tokens: 40_000,
2776 output_tokens: 20_000,
2777 cache_creation_input_tokens: 0,
2778 cache_read_input_tokens: 0,
2779 },
2780 ));
2781 cx.run_until_parked();
2782 thread.read_with(cx, |thread, _| {
2783 assert_eq!(
2784 thread.to_markdown(),
2785 indoc! {"
2786 ## User
2787
2788 Hi
2789
2790 ## Assistant
2791
2792 Ahoy!
2793 "}
2794 );
2795
2796 assert_eq!(
2797 thread.latest_token_usage(),
2798 Some(acp_thread::TokenUsage {
2799 used_tokens: 40_000 + 20_000,
2800 max_tokens: 1_000_000,
2801 max_output_tokens: None,
2802 input_tokens: 40_000,
2803 output_tokens: 20_000,
2804 })
2805 );
2806 });
2807}
2808
2809#[gpui::test]
2810async fn test_truncate_second_message(cx: &mut TestAppContext) {
2811 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2812 let fake_model = model.as_fake();
2813
2814 thread
2815 .update(cx, |thread, cx| {
2816 thread.send(UserMessageId::new(), ["Message 1"], cx)
2817 })
2818 .unwrap();
2819 cx.run_until_parked();
2820 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2821 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2822 language_model::TokenUsage {
2823 input_tokens: 32_000,
2824 output_tokens: 16_000,
2825 cache_creation_input_tokens: 0,
2826 cache_read_input_tokens: 0,
2827 },
2828 ));
2829 fake_model.end_last_completion_stream();
2830 cx.run_until_parked();
2831
2832 let assert_first_message_state = |cx: &mut TestAppContext| {
2833 thread.clone().read_with(cx, |thread, _| {
2834 assert_eq!(
2835 thread.to_markdown(),
2836 indoc! {"
2837 ## User
2838
2839 Message 1
2840
2841 ## Assistant
2842
2843 Message 1 response
2844 "}
2845 );
2846
2847 assert_eq!(
2848 thread.latest_token_usage(),
2849 Some(acp_thread::TokenUsage {
2850 used_tokens: 32_000 + 16_000,
2851 max_tokens: 1_000_000,
2852 max_output_tokens: None,
2853 input_tokens: 32_000,
2854 output_tokens: 16_000,
2855 })
2856 );
2857 });
2858 };
2859
2860 assert_first_message_state(cx);
2861
2862 let second_message_id = UserMessageId::new();
2863 thread
2864 .update(cx, |thread, cx| {
2865 thread.send(second_message_id.clone(), ["Message 2"], cx)
2866 })
2867 .unwrap();
2868 cx.run_until_parked();
2869
2870 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2871 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2872 language_model::TokenUsage {
2873 input_tokens: 40_000,
2874 output_tokens: 20_000,
2875 cache_creation_input_tokens: 0,
2876 cache_read_input_tokens: 0,
2877 },
2878 ));
2879 fake_model.end_last_completion_stream();
2880 cx.run_until_parked();
2881
2882 thread.read_with(cx, |thread, _| {
2883 assert_eq!(
2884 thread.to_markdown(),
2885 indoc! {"
2886 ## User
2887
2888 Message 1
2889
2890 ## Assistant
2891
2892 Message 1 response
2893
2894 ## User
2895
2896 Message 2
2897
2898 ## Assistant
2899
2900 Message 2 response
2901 "}
2902 );
2903
2904 assert_eq!(
2905 thread.latest_token_usage(),
2906 Some(acp_thread::TokenUsage {
2907 used_tokens: 40_000 + 20_000,
2908 max_tokens: 1_000_000,
2909 max_output_tokens: None,
2910 input_tokens: 40_000,
2911 output_tokens: 20_000,
2912 })
2913 );
2914 });
2915
2916 thread
2917 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2918 .unwrap();
2919 cx.run_until_parked();
2920
2921 assert_first_message_state(cx);
2922}
2923
2924#[gpui::test]
2925async fn test_title_generation(cx: &mut TestAppContext) {
2926 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2927 let fake_model = model.as_fake();
2928
2929 let summary_model = Arc::new(FakeLanguageModel::default());
2930 thread.update(cx, |thread, cx| {
2931 thread.set_summarization_model(Some(summary_model.clone()), cx)
2932 });
2933
2934 let send = thread
2935 .update(cx, |thread, cx| {
2936 thread.send(UserMessageId::new(), ["Hello"], cx)
2937 })
2938 .unwrap();
2939 cx.run_until_parked();
2940
2941 fake_model.send_last_completion_stream_text_chunk("Hey!");
2942 fake_model.end_last_completion_stream();
2943 cx.run_until_parked();
2944 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2945
2946 // Ensure the summary model has been invoked to generate a title.
2947 summary_model.send_last_completion_stream_text_chunk("Hello ");
2948 summary_model.send_last_completion_stream_text_chunk("world\nG");
2949 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2950 summary_model.end_last_completion_stream();
2951 send.collect::<Vec<_>>().await;
2952 cx.run_until_parked();
2953 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2954
2955 // Send another message, ensuring no title is generated this time.
2956 let send = thread
2957 .update(cx, |thread, cx| {
2958 thread.send(UserMessageId::new(), ["Hello again"], cx)
2959 })
2960 .unwrap();
2961 cx.run_until_parked();
2962 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2963 fake_model.end_last_completion_stream();
2964 cx.run_until_parked();
2965 assert_eq!(summary_model.pending_completions(), Vec::new());
2966 send.collect::<Vec<_>>().await;
2967 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2968}
2969
2970#[gpui::test]
2971async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2972 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2973 let fake_model = model.as_fake();
2974
2975 let _events = thread
2976 .update(cx, |thread, cx| {
2977 thread.add_tool(ToolRequiringPermission, None);
2978 thread.add_tool(EchoTool, None);
2979 thread.send(UserMessageId::new(), ["Hey!"], cx)
2980 })
2981 .unwrap();
2982 cx.run_until_parked();
2983
2984 let permission_tool_use = LanguageModelToolUse {
2985 id: "tool_id_1".into(),
2986 name: ToolRequiringPermission::NAME.into(),
2987 raw_input: "{}".into(),
2988 input: json!({}),
2989 is_input_complete: true,
2990 thought_signature: None,
2991 };
2992 let echo_tool_use = LanguageModelToolUse {
2993 id: "tool_id_2".into(),
2994 name: EchoTool::NAME.into(),
2995 raw_input: json!({"text": "test"}).to_string(),
2996 input: json!({"text": "test"}),
2997 is_input_complete: true,
2998 thought_signature: None,
2999 };
3000 fake_model.send_last_completion_stream_text_chunk("Hi!");
3001 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3002 permission_tool_use,
3003 ));
3004 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3005 echo_tool_use.clone(),
3006 ));
3007 fake_model.end_last_completion_stream();
3008 cx.run_until_parked();
3009
3010 // Ensure pending tools are skipped when building a request.
3011 let request = thread
3012 .read_with(cx, |thread, cx| {
3013 thread.build_completion_request(CompletionIntent::EditFile, cx)
3014 })
3015 .unwrap();
3016 assert_eq!(
3017 request.messages[1..],
3018 vec![
3019 LanguageModelRequestMessage {
3020 role: Role::User,
3021 content: vec!["Hey!".into()],
3022 cache: true,
3023 reasoning_details: None,
3024 },
3025 LanguageModelRequestMessage {
3026 role: Role::Assistant,
3027 content: vec![
3028 MessageContent::Text("Hi!".into()),
3029 MessageContent::ToolUse(echo_tool_use.clone())
3030 ],
3031 cache: false,
3032 reasoning_details: None,
3033 },
3034 LanguageModelRequestMessage {
3035 role: Role::User,
3036 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3037 tool_use_id: echo_tool_use.id.clone(),
3038 tool_name: echo_tool_use.name,
3039 is_error: false,
3040 content: "test".into(),
3041 output: Some("test".into())
3042 })],
3043 cache: false,
3044 reasoning_details: None,
3045 },
3046 ],
3047 );
3048}
3049
3050#[gpui::test]
3051async fn test_agent_connection(cx: &mut TestAppContext) {
3052 cx.update(settings::init);
3053 let templates = Templates::new();
3054
3055 // Initialize language model system with test provider
3056 cx.update(|cx| {
3057 gpui_tokio::init(cx);
3058
3059 let http_client = FakeHttpClient::with_404_response();
3060 let clock = Arc::new(clock::FakeSystemClock::new());
3061 let client = Client::new(clock, http_client, cx);
3062 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3063 language_model::init(client.clone(), cx);
3064 language_models::init(user_store, client.clone(), cx);
3065 LanguageModelRegistry::test(cx);
3066 });
3067 cx.executor().forbid_parking();
3068
3069 // Create a project for new_thread
3070 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3071 fake_fs.insert_tree(path!("/test"), json!({})).await;
3072 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3073 let cwd = Path::new("/test");
3074 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3075
3076 // Create agent and connection
3077 let agent = NativeAgent::new(
3078 project.clone(),
3079 thread_store,
3080 templates.clone(),
3081 None,
3082 fake_fs.clone(),
3083 &mut cx.to_async(),
3084 )
3085 .await
3086 .unwrap();
3087 let connection = NativeAgentConnection(agent.clone());
3088
3089 // Create a thread using new_thread
3090 let connection_rc = Rc::new(connection.clone());
3091 let acp_thread = cx
3092 .update(|cx| connection_rc.new_session(project, cwd, cx))
3093 .await
3094 .expect("new_thread should succeed");
3095
3096 // Get the session_id from the AcpThread
3097 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3098
3099 // Test model_selector returns Some
3100 let selector_opt = connection.model_selector(&session_id);
3101 assert!(
3102 selector_opt.is_some(),
3103 "agent should always support ModelSelector"
3104 );
3105 let selector = selector_opt.unwrap();
3106
3107 // Test list_models
3108 let listed_models = cx
3109 .update(|cx| selector.list_models(cx))
3110 .await
3111 .expect("list_models should succeed");
3112 let AgentModelList::Grouped(listed_models) = listed_models else {
3113 panic!("Unexpected model list type");
3114 };
3115 assert!(!listed_models.is_empty(), "should have at least one model");
3116 assert_eq!(
3117 listed_models[&AgentModelGroupName("Fake".into())][0]
3118 .id
3119 .0
3120 .as_ref(),
3121 "fake/fake"
3122 );
3123
3124 // Test selected_model returns the default
3125 let model = cx
3126 .update(|cx| selector.selected_model(cx))
3127 .await
3128 .expect("selected_model should succeed");
3129 let model = cx
3130 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3131 .unwrap();
3132 let model = model.as_fake();
3133 assert_eq!(model.id().0, "fake", "should return default model");
3134
3135 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3136 cx.run_until_parked();
3137 model.send_last_completion_stream_text_chunk("def");
3138 cx.run_until_parked();
3139 acp_thread.read_with(cx, |thread, cx| {
3140 assert_eq!(
3141 thread.to_markdown(cx),
3142 indoc! {"
3143 ## User
3144
3145 abc
3146
3147 ## Assistant
3148
3149 def
3150
3151 "}
3152 )
3153 });
3154
3155 // Test cancel
3156 cx.update(|cx| connection.cancel(&session_id, cx));
3157 request.await.expect("prompt should fail gracefully");
3158
3159 // Explicitly close the session and drop the ACP thread.
3160 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3161 .await
3162 .unwrap();
3163 drop(acp_thread);
3164 let result = cx
3165 .update(|cx| {
3166 connection.prompt(
3167 Some(acp_thread::UserMessageId::new()),
3168 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3169 cx,
3170 )
3171 })
3172 .await;
3173 assert_eq!(
3174 result.as_ref().unwrap_err().to_string(),
3175 "Session not found",
3176 "unexpected result: {:?}",
3177 result
3178 );
3179}
3180
3181#[gpui::test]
3182async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3183 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3184 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3185 let fake_model = model.as_fake();
3186
3187 let mut events = thread
3188 .update(cx, |thread, cx| {
3189 thread.send(UserMessageId::new(), ["Echo something"], cx)
3190 })
3191 .unwrap();
3192 cx.run_until_parked();
3193
3194 // Simulate streaming partial input.
3195 let input = json!({});
3196 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3197 LanguageModelToolUse {
3198 id: "1".into(),
3199 name: EchoTool::NAME.into(),
3200 raw_input: input.to_string(),
3201 input,
3202 is_input_complete: false,
3203 thought_signature: None,
3204 },
3205 ));
3206
3207 // Input streaming completed
3208 let input = json!({ "text": "Hello!" });
3209 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3210 LanguageModelToolUse {
3211 id: "1".into(),
3212 name: "echo".into(),
3213 raw_input: input.to_string(),
3214 input,
3215 is_input_complete: true,
3216 thought_signature: None,
3217 },
3218 ));
3219 fake_model.end_last_completion_stream();
3220 cx.run_until_parked();
3221
3222 let tool_call = expect_tool_call(&mut events).await;
3223 assert_eq!(
3224 tool_call,
3225 acp::ToolCall::new("1", "Echo")
3226 .raw_input(json!({}))
3227 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3228 );
3229 let update = expect_tool_call_update_fields(&mut events).await;
3230 assert_eq!(
3231 update,
3232 acp::ToolCallUpdate::new(
3233 "1",
3234 acp::ToolCallUpdateFields::new()
3235 .title("Echo")
3236 .kind(acp::ToolKind::Other)
3237 .raw_input(json!({ "text": "Hello!"}))
3238 )
3239 );
3240 let update = expect_tool_call_update_fields(&mut events).await;
3241 assert_eq!(
3242 update,
3243 acp::ToolCallUpdate::new(
3244 "1",
3245 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3246 )
3247 );
3248 let update = expect_tool_call_update_fields(&mut events).await;
3249 assert_eq!(
3250 update,
3251 acp::ToolCallUpdate::new(
3252 "1",
3253 acp::ToolCallUpdateFields::new()
3254 .status(acp::ToolCallStatus::Completed)
3255 .raw_output("Hello!")
3256 )
3257 );
3258}
3259
3260#[gpui::test]
3261async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3262 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3263 let fake_model = model.as_fake();
3264
3265 let mut events = thread
3266 .update(cx, |thread, cx| {
3267 thread.send(UserMessageId::new(), ["Hello!"], cx)
3268 })
3269 .unwrap();
3270 cx.run_until_parked();
3271
3272 fake_model.send_last_completion_stream_text_chunk("Hey!");
3273 fake_model.end_last_completion_stream();
3274
3275 let mut retry_events = Vec::new();
3276 while let Some(Ok(event)) = events.next().await {
3277 match event {
3278 ThreadEvent::Retry(retry_status) => {
3279 retry_events.push(retry_status);
3280 }
3281 ThreadEvent::Stop(..) => break,
3282 _ => {}
3283 }
3284 }
3285
3286 assert_eq!(retry_events.len(), 0);
3287 thread.read_with(cx, |thread, _cx| {
3288 assert_eq!(
3289 thread.to_markdown(),
3290 indoc! {"
3291 ## User
3292
3293 Hello!
3294
3295 ## Assistant
3296
3297 Hey!
3298 "}
3299 )
3300 });
3301}
3302
3303#[gpui::test]
3304async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3305 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3306 let fake_model = model.as_fake();
3307
3308 let mut events = thread
3309 .update(cx, |thread, cx| {
3310 thread.send(UserMessageId::new(), ["Hello!"], cx)
3311 })
3312 .unwrap();
3313 cx.run_until_parked();
3314
3315 fake_model.send_last_completion_stream_text_chunk("Hey,");
3316 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3317 provider: LanguageModelProviderName::new("Anthropic"),
3318 retry_after: Some(Duration::from_secs(3)),
3319 });
3320 fake_model.end_last_completion_stream();
3321
3322 cx.executor().advance_clock(Duration::from_secs(3));
3323 cx.run_until_parked();
3324
3325 fake_model.send_last_completion_stream_text_chunk("there!");
3326 fake_model.end_last_completion_stream();
3327 cx.run_until_parked();
3328
3329 let mut retry_events = Vec::new();
3330 while let Some(Ok(event)) = events.next().await {
3331 match event {
3332 ThreadEvent::Retry(retry_status) => {
3333 retry_events.push(retry_status);
3334 }
3335 ThreadEvent::Stop(..) => break,
3336 _ => {}
3337 }
3338 }
3339
3340 assert_eq!(retry_events.len(), 1);
3341 assert!(matches!(
3342 retry_events[0],
3343 acp_thread::RetryStatus { attempt: 1, .. }
3344 ));
3345 thread.read_with(cx, |thread, _cx| {
3346 assert_eq!(
3347 thread.to_markdown(),
3348 indoc! {"
3349 ## User
3350
3351 Hello!
3352
3353 ## Assistant
3354
3355 Hey,
3356
3357 [resume]
3358
3359 ## Assistant
3360
3361 there!
3362 "}
3363 )
3364 });
3365}
3366
3367#[gpui::test]
3368async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3369 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3370 let fake_model = model.as_fake();
3371
3372 let events = thread
3373 .update(cx, |thread, cx| {
3374 thread.add_tool(EchoTool, None);
3375 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3376 })
3377 .unwrap();
3378 cx.run_until_parked();
3379
3380 let tool_use_1 = LanguageModelToolUse {
3381 id: "tool_1".into(),
3382 name: EchoTool::NAME.into(),
3383 raw_input: json!({"text": "test"}).to_string(),
3384 input: json!({"text": "test"}),
3385 is_input_complete: true,
3386 thought_signature: None,
3387 };
3388 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3389 tool_use_1.clone(),
3390 ));
3391 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3392 provider: LanguageModelProviderName::new("Anthropic"),
3393 retry_after: Some(Duration::from_secs(3)),
3394 });
3395 fake_model.end_last_completion_stream();
3396
3397 cx.executor().advance_clock(Duration::from_secs(3));
3398 let completion = fake_model.pending_completions().pop().unwrap();
3399 assert_eq!(
3400 completion.messages[1..],
3401 vec![
3402 LanguageModelRequestMessage {
3403 role: Role::User,
3404 content: vec!["Call the echo tool!".into()],
3405 cache: false,
3406 reasoning_details: None,
3407 },
3408 LanguageModelRequestMessage {
3409 role: Role::Assistant,
3410 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3411 cache: false,
3412 reasoning_details: None,
3413 },
3414 LanguageModelRequestMessage {
3415 role: Role::User,
3416 content: vec![language_model::MessageContent::ToolResult(
3417 LanguageModelToolResult {
3418 tool_use_id: tool_use_1.id.clone(),
3419 tool_name: tool_use_1.name.clone(),
3420 is_error: false,
3421 content: "test".into(),
3422 output: Some("test".into())
3423 }
3424 )],
3425 cache: true,
3426 reasoning_details: None,
3427 },
3428 ]
3429 );
3430
3431 fake_model.send_last_completion_stream_text_chunk("Done");
3432 fake_model.end_last_completion_stream();
3433 cx.run_until_parked();
3434 events.collect::<Vec<_>>().await;
3435 thread.read_with(cx, |thread, _cx| {
3436 assert_eq!(
3437 thread.last_message(),
3438 Some(Message::Agent(AgentMessage {
3439 content: vec![AgentMessageContent::Text("Done".into())],
3440 tool_results: IndexMap::default(),
3441 reasoning_details: None,
3442 }))
3443 );
3444 })
3445}
3446
3447#[gpui::test]
3448async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3449 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3450 let fake_model = model.as_fake();
3451
3452 let mut events = thread
3453 .update(cx, |thread, cx| {
3454 thread.send(UserMessageId::new(), ["Hello!"], cx)
3455 })
3456 .unwrap();
3457 cx.run_until_parked();
3458
3459 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3460 fake_model.send_last_completion_stream_error(
3461 LanguageModelCompletionError::ServerOverloaded {
3462 provider: LanguageModelProviderName::new("Anthropic"),
3463 retry_after: Some(Duration::from_secs(3)),
3464 },
3465 );
3466 fake_model.end_last_completion_stream();
3467 cx.executor().advance_clock(Duration::from_secs(3));
3468 cx.run_until_parked();
3469 }
3470
3471 let mut errors = Vec::new();
3472 let mut retry_events = Vec::new();
3473 while let Some(event) = events.next().await {
3474 match event {
3475 Ok(ThreadEvent::Retry(retry_status)) => {
3476 retry_events.push(retry_status);
3477 }
3478 Ok(ThreadEvent::Stop(..)) => break,
3479 Err(error) => errors.push(error),
3480 _ => {}
3481 }
3482 }
3483
3484 assert_eq!(
3485 retry_events.len(),
3486 crate::thread::MAX_RETRY_ATTEMPTS as usize
3487 );
3488 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3489 assert_eq!(retry_events[i].attempt, i + 1);
3490 }
3491 assert_eq!(errors.len(), 1);
3492 let error = errors[0]
3493 .downcast_ref::<LanguageModelCompletionError>()
3494 .unwrap();
3495 assert!(matches!(
3496 error,
3497 LanguageModelCompletionError::ServerOverloaded { .. }
3498 ));
3499}
3500
3501/// Filters out the stop events for asserting against in tests
3502fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3503 result_events
3504 .into_iter()
3505 .filter_map(|event| match event.unwrap() {
3506 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3507 _ => None,
3508 })
3509 .collect()
3510}
3511
/// Fixture returned by `setup`: the thread under test together with the
/// collaborators it was built from.
struct ThreadTest {
    /// Model that was handed to `Thread::new`.
    model: Arc<dyn LanguageModel>,
    /// The thread created by `setup`.
    thread: Entity<Thread>,
    /// Default `ProjectContext` entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// Context server store taken from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem holding the settings file and the `/test` tree.
    fs: Arc<FakeFs>,
}
3519
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// Model looked up in the registry under the id `claude-sonnet-4-latest`.
    Sonnet4,
    /// An in-process `FakeLanguageModel`.
    Fake,
}
3524
3525impl TestModel {
3526 fn id(&self) -> LanguageModelId {
3527 match self {
3528 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3529 TestModel::Fake => unreachable!(),
3530 }
3531 }
3532}
3533
3534async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3535 cx.executor().allow_parking();
3536
3537 let fs = FakeFs::new(cx.background_executor.clone());
3538 fs.create_dir(paths::settings_file().parent().unwrap())
3539 .await
3540 .unwrap();
3541 fs.insert_file(
3542 paths::settings_file(),
3543 json!({
3544 "agent": {
3545 "default_profile": "test-profile",
3546 "profiles": {
3547 "test-profile": {
3548 "name": "Test Profile",
3549 "tools": {
3550 EchoTool::NAME: true,
3551 DelayTool::NAME: true,
3552 WordListTool::NAME: true,
3553 ToolRequiringPermission::NAME: true,
3554 InfiniteTool::NAME: true,
3555 CancellationAwareTool::NAME: true,
3556 (TerminalTool::NAME): true,
3557 }
3558 }
3559 }
3560 }
3561 })
3562 .to_string()
3563 .into_bytes(),
3564 )
3565 .await;
3566
3567 cx.update(|cx| {
3568 settings::init(cx);
3569
3570 match model {
3571 TestModel::Fake => {}
3572 TestModel::Sonnet4 => {
3573 gpui_tokio::init(cx);
3574 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3575 cx.set_http_client(Arc::new(http_client));
3576 let client = Client::production(cx);
3577 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3578 language_model::init(client.clone(), cx);
3579 language_models::init(user_store, client.clone(), cx);
3580 }
3581 };
3582
3583 watch_settings(fs.clone(), cx);
3584 });
3585
3586 let templates = Templates::new();
3587
3588 fs.insert_tree(path!("/test"), json!({})).await;
3589 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3590
3591 let model = cx
3592 .update(|cx| {
3593 if let TestModel::Fake = model {
3594 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3595 } else {
3596 let model_id = model.id();
3597 let models = LanguageModelRegistry::read_global(cx);
3598 let model = models
3599 .available_models(cx)
3600 .find(|model| model.id() == model_id)
3601 .unwrap();
3602
3603 let provider = models.provider(&model.provider_id()).unwrap();
3604 let authenticated = provider.authenticate(cx);
3605
3606 cx.spawn(async move |_cx| {
3607 authenticated.await.unwrap();
3608 model
3609 })
3610 }
3611 })
3612 .await;
3613
3614 let project_context = cx.new(|_cx| ProjectContext::default());
3615 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3616 let context_server_registry =
3617 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3618 let thread = cx.new(|cx| {
3619 Thread::new(
3620 project,
3621 project_context.clone(),
3622 context_server_registry,
3623 templates,
3624 Some(model.clone()),
3625 cx,
3626 )
3627 });
3628 ThreadTest {
3629 model,
3630 thread,
3631 project_context,
3632 context_server_store,
3633 fs,
3634 }
3635}
3636
// Registered through `ctor`; installs `env_logger` only when the `RUST_LOG`
// environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}
3644
3645fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3646 let fs = fs.clone();
3647 cx.spawn({
3648 async move |cx| {
3649 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3650 cx.background_executor(),
3651 fs,
3652 paths::settings_file().clone(),
3653 );
3654 let _watcher_task = watcher_task;
3655
3656 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3657 cx.update(|cx| {
3658 SettingsStore::update_global(cx, |settings, cx| {
3659 settings.set_user_settings(&new_settings_content, cx)
3660 })
3661 })
3662 .ok();
3663 }
3664 }
3665 })
3666 .detach();
3667}
3668
3669fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3670 completion
3671 .tools
3672 .iter()
3673 .map(|tool| tool.name.clone())
3674 .collect()
3675}
3676
3677fn setup_context_server(
3678 name: &'static str,
3679 tools: Vec<context_server::types::Tool>,
3680 context_server_store: &Entity<ContextServerStore>,
3681 cx: &mut TestAppContext,
3682) -> mpsc::UnboundedReceiver<(
3683 context_server::types::CallToolParams,
3684 oneshot::Sender<context_server::types::CallToolResponse>,
3685)> {
3686 cx.update(|cx| {
3687 let mut settings = ProjectSettings::get_global(cx).clone();
3688 settings.context_servers.insert(
3689 name.into(),
3690 project::project_settings::ContextServerSettings::Stdio {
3691 enabled: true,
3692 remote: false,
3693 command: ContextServerCommand {
3694 path: "somebinary".into(),
3695 args: Vec::new(),
3696 env: None,
3697 timeout: None,
3698 },
3699 },
3700 );
3701 ProjectSettings::override_global(settings, cx);
3702 });
3703
3704 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3705 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3706 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3707 context_server::types::InitializeResponse {
3708 protocol_version: context_server::types::ProtocolVersion(
3709 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3710 ),
3711 server_info: context_server::types::Implementation {
3712 name: name.into(),
3713 version: "1.0.0".to_string(),
3714 },
3715 capabilities: context_server::types::ServerCapabilities {
3716 tools: Some(context_server::types::ToolsCapabilities {
3717 list_changed: Some(true),
3718 }),
3719 ..Default::default()
3720 },
3721 meta: None,
3722 }
3723 })
3724 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3725 let tools = tools.clone();
3726 async move {
3727 context_server::types::ListToolsResponse {
3728 tools,
3729 next_cursor: None,
3730 meta: None,
3731 }
3732 }
3733 })
3734 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3735 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3736 async move {
3737 let (response_tx, response_rx) = oneshot::channel();
3738 mcp_tool_calls_tx
3739 .unbounded_send((params, response_tx))
3740 .unwrap();
3741 response_rx.await.unwrap()
3742 }
3743 });
3744 context_server_store.update(cx, |store, cx| {
3745 store.start_server(
3746 Arc::new(ContextServer::new(
3747 ContextServerId(name.into()),
3748 Arc::new(fake_transport),
3749 )),
3750 cx,
3751 );
3752 });
3753 cx.run_until_parked();
3754 mcp_tool_calls_rx
3755}
3756
3757#[gpui::test]
3758async fn test_tokens_before_message(cx: &mut TestAppContext) {
3759 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3760 let fake_model = model.as_fake();
3761
3762 // First message
3763 let message_1_id = UserMessageId::new();
3764 thread
3765 .update(cx, |thread, cx| {
3766 thread.send(message_1_id.clone(), ["First message"], cx)
3767 })
3768 .unwrap();
3769 cx.run_until_parked();
3770
3771 // Before any response, tokens_before_message should return None for first message
3772 thread.read_with(cx, |thread, _| {
3773 assert_eq!(
3774 thread.tokens_before_message(&message_1_id),
3775 None,
3776 "First message should have no tokens before it"
3777 );
3778 });
3779
3780 // Complete first message with usage
3781 fake_model.send_last_completion_stream_text_chunk("Response 1");
3782 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3783 language_model::TokenUsage {
3784 input_tokens: 100,
3785 output_tokens: 50,
3786 cache_creation_input_tokens: 0,
3787 cache_read_input_tokens: 0,
3788 },
3789 ));
3790 fake_model.end_last_completion_stream();
3791 cx.run_until_parked();
3792
3793 // First message still has no tokens before it
3794 thread.read_with(cx, |thread, _| {
3795 assert_eq!(
3796 thread.tokens_before_message(&message_1_id),
3797 None,
3798 "First message should still have no tokens before it after response"
3799 );
3800 });
3801
3802 // Second message
3803 let message_2_id = UserMessageId::new();
3804 thread
3805 .update(cx, |thread, cx| {
3806 thread.send(message_2_id.clone(), ["Second message"], cx)
3807 })
3808 .unwrap();
3809 cx.run_until_parked();
3810
3811 // Second message should have first message's input tokens before it
3812 thread.read_with(cx, |thread, _| {
3813 assert_eq!(
3814 thread.tokens_before_message(&message_2_id),
3815 Some(100),
3816 "Second message should have 100 tokens before it (from first request)"
3817 );
3818 });
3819
3820 // Complete second message
3821 fake_model.send_last_completion_stream_text_chunk("Response 2");
3822 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3823 language_model::TokenUsage {
3824 input_tokens: 250, // Total for this request (includes previous context)
3825 output_tokens: 75,
3826 cache_creation_input_tokens: 0,
3827 cache_read_input_tokens: 0,
3828 },
3829 ));
3830 fake_model.end_last_completion_stream();
3831 cx.run_until_parked();
3832
3833 // Third message
3834 let message_3_id = UserMessageId::new();
3835 thread
3836 .update(cx, |thread, cx| {
3837 thread.send(message_3_id.clone(), ["Third message"], cx)
3838 })
3839 .unwrap();
3840 cx.run_until_parked();
3841
3842 // Third message should have second message's input tokens (250) before it
3843 thread.read_with(cx, |thread, _| {
3844 assert_eq!(
3845 thread.tokens_before_message(&message_3_id),
3846 Some(250),
3847 "Third message should have 250 tokens before it (from second request)"
3848 );
3849 // Second message should still have 100
3850 assert_eq!(
3851 thread.tokens_before_message(&message_2_id),
3852 Some(100),
3853 "Second message should still have 100 tokens before it"
3854 );
3855 // First message still has none
3856 assert_eq!(
3857 thread.tokens_before_message(&message_1_id),
3858 None,
3859 "First message should still have no tokens before it"
3860 );
3861 });
3862}
3863
3864#[gpui::test]
3865async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3866 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3867 let fake_model = model.as_fake();
3868
3869 // Set up three messages with responses
3870 let message_1_id = UserMessageId::new();
3871 thread
3872 .update(cx, |thread, cx| {
3873 thread.send(message_1_id.clone(), ["Message 1"], cx)
3874 })
3875 .unwrap();
3876 cx.run_until_parked();
3877 fake_model.send_last_completion_stream_text_chunk("Response 1");
3878 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3879 language_model::TokenUsage {
3880 input_tokens: 100,
3881 output_tokens: 50,
3882 cache_creation_input_tokens: 0,
3883 cache_read_input_tokens: 0,
3884 },
3885 ));
3886 fake_model.end_last_completion_stream();
3887 cx.run_until_parked();
3888
3889 let message_2_id = UserMessageId::new();
3890 thread
3891 .update(cx, |thread, cx| {
3892 thread.send(message_2_id.clone(), ["Message 2"], cx)
3893 })
3894 .unwrap();
3895 cx.run_until_parked();
3896 fake_model.send_last_completion_stream_text_chunk("Response 2");
3897 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3898 language_model::TokenUsage {
3899 input_tokens: 250,
3900 output_tokens: 75,
3901 cache_creation_input_tokens: 0,
3902 cache_read_input_tokens: 0,
3903 },
3904 ));
3905 fake_model.end_last_completion_stream();
3906 cx.run_until_parked();
3907
3908 // Verify initial state
3909 thread.read_with(cx, |thread, _| {
3910 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3911 });
3912
3913 // Truncate at message 2 (removes message 2 and everything after)
3914 thread
3915 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3916 .unwrap();
3917 cx.run_until_parked();
3918
3919 // After truncation, message_2_id no longer exists, so lookup should return None
3920 thread.read_with(cx, |thread, _| {
3921 assert_eq!(
3922 thread.tokens_before_message(&message_2_id),
3923 None,
3924 "After truncation, message 2 no longer exists"
3925 );
3926 // Message 1 still exists but has no tokens before it
3927 assert_eq!(
3928 thread.tokens_before_message(&message_1_id),
3929 None,
3930 "First message still has no tokens before it"
3931 );
3932 });
3933}
3934
3935#[gpui::test]
3936async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3937 init_test(cx);
3938
3939 let fs = FakeFs::new(cx.executor());
3940 fs.insert_tree("/root", json!({})).await;
3941 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3942
3943 // Test 1: Deny rule blocks command
3944 {
3945 let environment = Rc::new(cx.update(|cx| {
3946 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3947 }));
3948
3949 cx.update(|cx| {
3950 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3951 settings.tool_permissions.tools.insert(
3952 TerminalTool::NAME.into(),
3953 agent_settings::ToolRules {
3954 default: Some(settings::ToolPermissionMode::Confirm),
3955 always_allow: vec![],
3956 always_deny: vec![
3957 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3958 ],
3959 always_confirm: vec![],
3960 invalid_patterns: vec![],
3961 },
3962 );
3963 agent_settings::AgentSettings::override_global(settings, cx);
3964 });
3965
3966 #[allow(clippy::arc_with_non_send_sync)]
3967 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3968 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3969
3970 let task = cx.update(|cx| {
3971 tool.run(
3972 crate::TerminalToolInput {
3973 command: "rm -rf /".to_string(),
3974 cd: ".".to_string(),
3975 timeout_ms: None,
3976 },
3977 event_stream,
3978 cx,
3979 )
3980 });
3981
3982 let result = task.await;
3983 assert!(
3984 result.is_err(),
3985 "expected command to be blocked by deny rule"
3986 );
3987 let err_msg = result.unwrap_err().to_string().to_lowercase();
3988 assert!(
3989 err_msg.contains("blocked"),
3990 "error should mention the command was blocked"
3991 );
3992 }
3993
3994 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
3995 {
3996 let environment = Rc::new(cx.update(|cx| {
3997 FakeThreadEnvironment::default()
3998 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
3999 }));
4000
4001 cx.update(|cx| {
4002 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4003 settings.tool_permissions.tools.insert(
4004 TerminalTool::NAME.into(),
4005 agent_settings::ToolRules {
4006 default: Some(settings::ToolPermissionMode::Deny),
4007 always_allow: vec![
4008 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4009 ],
4010 always_deny: vec![],
4011 always_confirm: vec![],
4012 invalid_patterns: vec![],
4013 },
4014 );
4015 agent_settings::AgentSettings::override_global(settings, cx);
4016 });
4017
4018 #[allow(clippy::arc_with_non_send_sync)]
4019 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4020 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4021
4022 let task = cx.update(|cx| {
4023 tool.run(
4024 crate::TerminalToolInput {
4025 command: "echo hello".to_string(),
4026 cd: ".".to_string(),
4027 timeout_ms: None,
4028 },
4029 event_stream,
4030 cx,
4031 )
4032 });
4033
4034 let update = rx.expect_update_fields().await;
4035 assert!(
4036 update.content.iter().any(|blocks| {
4037 blocks
4038 .iter()
4039 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4040 }),
4041 "expected terminal content (allow rule should skip confirmation and override default deny)"
4042 );
4043
4044 let result = task.await;
4045 assert!(
4046 result.is_ok(),
4047 "expected command to succeed without confirmation"
4048 );
4049 }
4050
4051 // Test 3: global default: allow does NOT override always_confirm patterns
4052 {
4053 let environment = Rc::new(cx.update(|cx| {
4054 FakeThreadEnvironment::default()
4055 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4056 }));
4057
4058 cx.update(|cx| {
4059 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4060 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4061 settings.tool_permissions.tools.insert(
4062 TerminalTool::NAME.into(),
4063 agent_settings::ToolRules {
4064 default: Some(settings::ToolPermissionMode::Allow),
4065 always_allow: vec![],
4066 always_deny: vec![],
4067 always_confirm: vec![
4068 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4069 ],
4070 invalid_patterns: vec![],
4071 },
4072 );
4073 agent_settings::AgentSettings::override_global(settings, cx);
4074 });
4075
4076 #[allow(clippy::arc_with_non_send_sync)]
4077 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4078 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4079
4080 let _task = cx.update(|cx| {
4081 tool.run(
4082 crate::TerminalToolInput {
4083 command: "sudo rm file".to_string(),
4084 cd: ".".to_string(),
4085 timeout_ms: None,
4086 },
4087 event_stream,
4088 cx,
4089 )
4090 });
4091
4092 // With global default: allow, confirm patterns are still respected
4093 // The expect_authorization() call will panic if no authorization is requested,
4094 // which validates that the confirm pattern still triggers confirmation
4095 let _auth = rx.expect_authorization().await;
4096
4097 drop(_task);
4098 }
4099
4100 // Test 4: tool-specific default: deny is respected even with global default: allow
4101 {
4102 let environment = Rc::new(cx.update(|cx| {
4103 FakeThreadEnvironment::default()
4104 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4105 }));
4106
4107 cx.update(|cx| {
4108 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4109 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4110 settings.tool_permissions.tools.insert(
4111 TerminalTool::NAME.into(),
4112 agent_settings::ToolRules {
4113 default: Some(settings::ToolPermissionMode::Deny),
4114 always_allow: vec![],
4115 always_deny: vec![],
4116 always_confirm: vec![],
4117 invalid_patterns: vec![],
4118 },
4119 );
4120 agent_settings::AgentSettings::override_global(settings, cx);
4121 });
4122
4123 #[allow(clippy::arc_with_non_send_sync)]
4124 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4125 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4126
4127 let task = cx.update(|cx| {
4128 tool.run(
4129 crate::TerminalToolInput {
4130 command: "echo hello".to_string(),
4131 cd: ".".to_string(),
4132 timeout_ms: None,
4133 },
4134 event_stream,
4135 cx,
4136 )
4137 });
4138
4139 // tool-specific default: deny is respected even with global default: allow
4140 let result = task.await;
4141 assert!(
4142 result.is_err(),
4143 "expected command to be blocked by tool-specific deny default"
4144 );
4145 let err_msg = result.unwrap_err().to_string().to_lowercase();
4146 assert!(
4147 err_msg.contains("disabled"),
4148 "error should mention the tool is disabled, got: {err_msg}"
4149 );
4150 }
4151}
4152
4153#[gpui::test]
4154async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4155 init_test(cx);
4156
4157 cx.update(|cx| {
4158 cx.update_flags(true, vec!["subagents".to_string()]);
4159 });
4160
4161 let fs = FakeFs::new(cx.executor());
4162 fs.insert_tree(path!("/test"), json!({})).await;
4163 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4164 let project_context = cx.new(|_cx| ProjectContext::default());
4165 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4166 let context_server_registry =
4167 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4168 let model = Arc::new(FakeLanguageModel::default());
4169
4170 let environment = Rc::new(cx.update(|cx| {
4171 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4172 }));
4173
4174 let thread = cx.new(|cx| {
4175 let mut thread = Thread::new(
4176 project.clone(),
4177 project_context,
4178 context_server_registry,
4179 Templates::new(),
4180 Some(model),
4181 cx,
4182 );
4183 thread.add_default_tools(None, environment, cx);
4184 thread
4185 });
4186
4187 thread.read_with(cx, |thread, _| {
4188 assert!(
4189 thread.has_registered_tool(SubagentTool::NAME),
4190 "subagent tool should be present when feature flag is enabled"
4191 );
4192 });
4193}
4194
4195#[gpui::test]
4196async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4197 init_test(cx);
4198
4199 cx.update(|cx| {
4200 cx.update_flags(true, vec!["subagents".to_string()]);
4201 });
4202
4203 let fs = FakeFs::new(cx.executor());
4204 fs.insert_tree(path!("/test"), json!({})).await;
4205 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4206 let project_context = cx.new(|_cx| ProjectContext::default());
4207 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4208 let context_server_registry =
4209 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4210 let model = Arc::new(FakeLanguageModel::default());
4211
4212 let parent_thread = cx.new(|cx| {
4213 Thread::new(
4214 project.clone(),
4215 project_context,
4216 context_server_registry,
4217 Templates::new(),
4218 Some(model.clone()),
4219 cx,
4220 )
4221 });
4222
4223 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4224 subagent_thread.read_with(cx, |subagent_thread, cx| {
4225 assert!(subagent_thread.is_subagent());
4226 assert_eq!(subagent_thread.depth(), 1);
4227 assert_eq!(
4228 subagent_thread.model().map(|model| model.id()),
4229 Some(model.id())
4230 );
4231 assert_eq!(
4232 subagent_thread.parent_thread_id(),
4233 Some(parent_thread.read(cx).id().clone())
4234 );
4235 });
4236}
4237
4238#[gpui::test]
4239async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4240 init_test(cx);
4241
4242 cx.update(|cx| {
4243 cx.update_flags(true, vec!["subagents".to_string()]);
4244 });
4245
4246 let fs = FakeFs::new(cx.executor());
4247 fs.insert_tree(path!("/test"), json!({})).await;
4248 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4249 let project_context = cx.new(|_cx| ProjectContext::default());
4250 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4251 let context_server_registry =
4252 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4253 let model = Arc::new(FakeLanguageModel::default());
4254 let environment = Rc::new(cx.update(|cx| {
4255 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4256 }));
4257
4258 let deep_parent_thread = cx.new(|cx| {
4259 let mut thread = Thread::new(
4260 project.clone(),
4261 project_context,
4262 context_server_registry,
4263 Templates::new(),
4264 Some(model.clone()),
4265 cx,
4266 );
4267 thread.set_subagent_context(SubagentContext {
4268 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4269 depth: MAX_SUBAGENT_DEPTH - 1,
4270 });
4271 thread
4272 });
4273 let deep_subagent_thread = cx.new(|cx| {
4274 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4275 thread.add_default_tools(None, environment, cx);
4276 thread
4277 });
4278
4279 deep_subagent_thread.read_with(cx, |thread, _| {
4280 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4281 assert!(
4282 !thread.has_registered_tool(SubagentTool::NAME),
4283 "subagent tool should not be present at max depth"
4284 );
4285 });
4286}
4287
#[gpui::test]
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
    // Verifies that cancelling a parent thread also completes the turn of a
    // subagent thread that was registered with it as running.
    init_test(cx);

    // Turn on the "subagents" feature flag for this test.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    // Empty fake worktree rooted at /test.
    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));

    // The parent only receives a weak handle to the subagent.
    parent.update(cx, |thread, _cx| {
        thread.register_running_subagent(subagent.downgrade());
    });

    // Start a turn on the subagent. Nothing in this test completes the fake
    // model's stream, so the turn is expected to still be in progress below.
    subagent
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
        })
        .unwrap();
    cx.run_until_parked();

    subagent.read_with(cx, |thread, _| {
        assert!(!thread.is_turn_complete(), "subagent should be running");
    });

    // Cancel the parent; the returned task is detached rather than awaited.
    parent.update(cx, |thread, cx| {
        thread.cancel(cx).detach();
    });

    // Checked without another `run_until_parked`, so the subagent's turn must
    // already be complete by the time `cancel` returns.
    subagent.read_with(cx, |thread, _| {
        assert!(
            thread.is_turn_complete(),
            "subagent should be cancelled when parent cancels"
        );
    });
}
4344
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()` and stops all running subagents.
    init_test(cx);
    always_allow_tools(cx);

    // Turn on the "subagents" feature flag for this test.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    // Environment whose subagent handle is built with `new_never_completes`, so
    // the tool should only finish through cancellation (or the timeout below).
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
    }));

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));

    // `_rx` is kept bound so the event receiver is not dropped during the test.
    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                label: "Long running task".to_string(),
                task_prompt: "Do a very long task that takes forever".to_string(),
                summary_prompt: "Summarize".to_string(),
                timeout_ms: None,
                allowed_tools: None,
            },
            event_stream.clone(),
            cx,
        )
    });

    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error
    // Race the tool task against a 5 second executor timer; the timer winning
    // means the tool ignored the cancellation signal.
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4422
4423#[gpui::test]
4424async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4425 init_test(cx);
4426 always_allow_tools(cx);
4427
4428 cx.update(|cx| {
4429 cx.update_flags(true, vec!["subagents".to_string()]);
4430 });
4431
4432 let fs = FakeFs::new(cx.executor());
4433 fs.insert_tree(path!("/test"), json!({})).await;
4434 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4435 let project_context = cx.new(|_cx| ProjectContext::default());
4436 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4437 let context_server_registry =
4438 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4439 cx.update(LanguageModelRegistry::test);
4440 let model = Arc::new(FakeLanguageModel::default());
4441 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4442 let native_agent = NativeAgent::new(
4443 project.clone(),
4444 thread_store,
4445 Templates::new(),
4446 None,
4447 fs,
4448 &mut cx.to_async(),
4449 )
4450 .await
4451 .unwrap();
4452 let parent_thread = cx.new(|cx| {
4453 Thread::new(
4454 project.clone(),
4455 project_context,
4456 context_server_registry,
4457 Templates::new(),
4458 Some(model.clone()),
4459 cx,
4460 )
4461 });
4462
4463 let mut handles = Vec::new();
4464 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4465 let handle = cx
4466 .update(|cx| {
4467 NativeThreadEnvironment::create_subagent_thread(
4468 native_agent.downgrade(),
4469 parent_thread.clone(),
4470 "some title".to_string(),
4471 "some task".to_string(),
4472 None,
4473 None,
4474 cx,
4475 )
4476 })
4477 .expect("Expected to be able to create subagent thread");
4478 handles.push(handle);
4479 }
4480
4481 let result = cx.update(|cx| {
4482 NativeThreadEnvironment::create_subagent_thread(
4483 native_agent.downgrade(),
4484 parent_thread.clone(),
4485 "some title".to_string(),
4486 "some task".to_string(),
4487 None,
4488 None,
4489 cx,
4490 )
4491 });
4492 assert!(result.is_err());
4493 assert_eq!(
4494 result.err().unwrap().to_string(),
4495 format!(
4496 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4497 MAX_PARALLEL_SUBAGENTS
4498 )
4499 );
4500}
4501
#[gpui::test]
async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
    // Drives a subagent through two model requests (the task, then the summary)
    // and checks that `wait_for_summary` resolves to the model's summary text.
    init_test(cx);

    always_allow_tools(cx);

    // Turn on the "subagents" feature flag for this test.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    cx.update(LanguageModelRegistry::test);
    let model = Arc::new(FakeLanguageModel::default());
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let native_agent = NativeAgent::new(
        project.clone(),
        thread_store,
        Templates::new(),
        None,
        fs,
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let parent_thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // NOTE(review): the 10 ms duration looks like the subagent timeout; this test
    // never advances the clock explicitly, so presumably it does not fire — confirm.
    let subagent_handle = cx
        .update(|cx| {
            NativeThreadEnvironment::create_subagent_thread(
                native_agent.downgrade(),
                parent_thread.clone(),
                "some title".to_string(),
                "task prompt".to_string(),
                Some(Duration::from_millis(10)),
                None,
                cx,
            )
        })
        .expect("Failed to create subagent");

    let summary_task =
        subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());

    cx.run_until_parked();

    // First request: inspect what the model was sent for the task itself.
    {
        let messages = model.pending_completions().last().unwrap().messages.clone();
        // Ensure that model received a system prompt
        assert_eq!(messages[0].role, Role::System);
        // Ensure that model received a task prompt
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(
            messages[1].content,
            vec![MessageContent::Text("task prompt".to_string())]
        );
    }

    // Answer the task request and close its stream.
    model.send_last_completion_stream_text_chunk("Some task response...");
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Second request: the task response is now in the history, followed by the
    // summary prompt as a new user message.
    {
        let messages = model.pending_completions().last().unwrap().messages.clone();
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(
            messages[2].content,
            vec![MessageContent::Text("Some task response...".to_string())]
        );
        // Ensure that model received a summary prompt
        assert_eq!(messages[3].role, Role::User);
        assert_eq!(
            messages[3].content,
            vec![MessageContent::Text("summary prompt".to_string())]
        );
    }

    // Answer the summary request and close its stream.
    model.send_last_completion_stream_text_chunk("Some summary...");
    model.end_last_completion_stream();

    // The resolved summary is the streamed text with a trailing newline.
    let result = summary_task.await;
    assert_eq!(result.unwrap(), "Some summary...\n");
}
4600
#[gpui::test]
async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
    cx: &mut TestAppContext,
) {
    // Lets the subagent's timeout elapse before the model answers the task, then
    // checks that the next user message is a timeout notice followed by the
    // summary prompt, and that the model's reply to it becomes the summary.
    init_test(cx);

    always_allow_tools(cx);

    // Turn on the "subagents" feature flag for this test.
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    cx.update(LanguageModelRegistry::test);
    let model = Arc::new(FakeLanguageModel::default());
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let native_agent = NativeAgent::new(
        project.clone(),
        thread_store,
        Templates::new(),
        None,
        fs,
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let parent_thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // 100 ms timeout; the clock is advanced past it further down.
    let subagent_handle = cx
        .update(|cx| {
            NativeThreadEnvironment::create_subagent_thread(
                native_agent.downgrade(),
                parent_thread.clone(),
                "some title".to_string(),
                "task prompt".to_string(),
                Some(Duration::from_millis(100)),
                None,
                cx,
            )
        })
        .expect("Failed to create subagent");

    let summary_task =
        subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());

    cx.run_until_parked();

    {
        let messages = model.pending_completions().last().unwrap().messages.clone();
        // Ensure that model received a system prompt
        assert_eq!(messages[0].role, Role::System);
        // Ensure that model received a task prompt
        assert_eq!(
            messages[1].content,
            vec![MessageContent::Text("task prompt".to_string())]
        );
    }

    // Don't complete the initial model stream — let the timeout expire instead.
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // After the timeout fires, the most recent user message in the pending
    // request is a "time exceeded" notice followed by the summary prompt.
    {
        let messages = model.pending_completions().last().unwrap().messages.clone();
        let last_user_message = messages
            .iter()
            .rev()
            .find(|m| m.role == Role::User)
            .unwrap();
        assert_eq!(
            last_user_message.content,
            vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
        );
    }

    // Answer the post-timeout request and close its stream.
    model.send_last_completion_stream_text_chunk("Some context low response...");
    model.end_last_completion_stream();

    // The resolved summary is the streamed text with a trailing newline.
    let result = summary_task.await;
    assert_eq!(result.unwrap(), "Some context low response...\n");
}
4699
4700#[gpui::test]
4701async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4702 init_test(cx);
4703
4704 always_allow_tools(cx);
4705
4706 cx.update(|cx| {
4707 cx.update_flags(true, vec!["subagents".to_string()]);
4708 });
4709
4710 let fs = FakeFs::new(cx.executor());
4711 fs.insert_tree(path!("/test"), json!({})).await;
4712 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4713 let project_context = cx.new(|_cx| ProjectContext::default());
4714 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4715 let context_server_registry =
4716 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4717 cx.update(LanguageModelRegistry::test);
4718 let model = Arc::new(FakeLanguageModel::default());
4719 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4720 let native_agent = NativeAgent::new(
4721 project.clone(),
4722 thread_store,
4723 Templates::new(),
4724 None,
4725 fs,
4726 &mut cx.to_async(),
4727 )
4728 .await
4729 .unwrap();
4730 let parent_thread = cx.new(|cx| {
4731 let mut thread = Thread::new(
4732 project.clone(),
4733 project_context,
4734 context_server_registry,
4735 Templates::new(),
4736 Some(model.clone()),
4737 cx,
4738 );
4739 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4740 thread.add_tool(GrepTool::new(project.clone()), None);
4741 thread
4742 });
4743
4744 let _subagent_handle = cx
4745 .update(|cx| {
4746 NativeThreadEnvironment::create_subagent_thread(
4747 native_agent.downgrade(),
4748 parent_thread.clone(),
4749 "some title".to_string(),
4750 "task prompt".to_string(),
4751 Some(Duration::from_millis(10)),
4752 None,
4753 cx,
4754 )
4755 })
4756 .expect("Failed to create subagent");
4757
4758 cx.run_until_parked();
4759
4760 let tools = model
4761 .pending_completions()
4762 .last()
4763 .unwrap()
4764 .tools
4765 .iter()
4766 .map(|tool| tool.name.clone())
4767 .collect::<Vec<_>>();
4768 assert_eq!(tools.len(), 2);
4769 assert!(tools.contains(&"grep".to_string()));
4770 assert!(tools.contains(&"list_directory".to_string()));
4771}
4772
4773#[gpui::test]
4774async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4775 init_test(cx);
4776
4777 always_allow_tools(cx);
4778
4779 cx.update(|cx| {
4780 cx.update_flags(true, vec!["subagents".to_string()]);
4781 });
4782
4783 let fs = FakeFs::new(cx.executor());
4784 fs.insert_tree(path!("/test"), json!({})).await;
4785 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4786 let project_context = cx.new(|_cx| ProjectContext::default());
4787 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4788 let context_server_registry =
4789 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4790 cx.update(LanguageModelRegistry::test);
4791 let model = Arc::new(FakeLanguageModel::default());
4792 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4793 let native_agent = NativeAgent::new(
4794 project.clone(),
4795 thread_store,
4796 Templates::new(),
4797 None,
4798 fs,
4799 &mut cx.to_async(),
4800 )
4801 .await
4802 .unwrap();
4803 let parent_thread = cx.new(|cx| {
4804 let mut thread = Thread::new(
4805 project.clone(),
4806 project_context,
4807 context_server_registry,
4808 Templates::new(),
4809 Some(model.clone()),
4810 cx,
4811 );
4812 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4813 thread.add_tool(GrepTool::new(project.clone()), None);
4814 thread
4815 });
4816
4817 let _subagent_handle = cx
4818 .update(|cx| {
4819 NativeThreadEnvironment::create_subagent_thread(
4820 native_agent.downgrade(),
4821 parent_thread.clone(),
4822 "some title".to_string(),
4823 "task prompt".to_string(),
4824 Some(Duration::from_millis(10)),
4825 Some(vec!["grep".to_string()]),
4826 cx,
4827 )
4828 })
4829 .expect("Failed to create subagent");
4830
4831 cx.run_until_parked();
4832
4833 let tools = model
4834 .pending_completions()
4835 .last()
4836 .unwrap()
4837 .tools
4838 .iter()
4839 .map(|tool| tool.name.clone())
4840 .collect::<Vec<_>>();
4841 assert_eq!(tools, vec!["grep"]);
4842}
4843
4844#[gpui::test]
4845async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4846 init_test(cx);
4847
4848 let fs = FakeFs::new(cx.executor());
4849 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4850 .await;
4851 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4852
4853 cx.update(|cx| {
4854 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4855 settings.tool_permissions.tools.insert(
4856 EditFileTool::NAME.into(),
4857 agent_settings::ToolRules {
4858 default: Some(settings::ToolPermissionMode::Allow),
4859 always_allow: vec![],
4860 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4861 always_confirm: vec![],
4862 invalid_patterns: vec![],
4863 },
4864 );
4865 agent_settings::AgentSettings::override_global(settings, cx);
4866 });
4867
4868 let context_server_registry =
4869 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4870 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4871 let templates = crate::Templates::new();
4872 let thread = cx.new(|cx| {
4873 crate::Thread::new(
4874 project.clone(),
4875 cx.new(|_cx| prompt_store::ProjectContext::default()),
4876 context_server_registry,
4877 templates.clone(),
4878 None,
4879 cx,
4880 )
4881 });
4882
4883 #[allow(clippy::arc_with_non_send_sync)]
4884 let tool = Arc::new(crate::EditFileTool::new(
4885 project.clone(),
4886 thread.downgrade(),
4887 language_registry,
4888 templates,
4889 ));
4890 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4891
4892 let task = cx.update(|cx| {
4893 tool.run(
4894 crate::EditFileToolInput {
4895 display_description: "Edit sensitive file".to_string(),
4896 path: "root/sensitive_config.txt".into(),
4897 mode: crate::EditFileMode::Edit,
4898 },
4899 event_stream,
4900 cx,
4901 )
4902 });
4903
4904 let result = task.await;
4905 assert!(result.is_err(), "expected edit to be blocked");
4906 assert!(
4907 result.unwrap_err().to_string().contains("blocked"),
4908 "error should mention the edit was blocked"
4909 );
4910}
4911
4912#[gpui::test]
4913async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4914 init_test(cx);
4915
4916 let fs = FakeFs::new(cx.executor());
4917 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4918 .await;
4919 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4920
4921 cx.update(|cx| {
4922 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4923 settings.tool_permissions.tools.insert(
4924 DeletePathTool::NAME.into(),
4925 agent_settings::ToolRules {
4926 default: Some(settings::ToolPermissionMode::Allow),
4927 always_allow: vec![],
4928 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4929 always_confirm: vec![],
4930 invalid_patterns: vec![],
4931 },
4932 );
4933 agent_settings::AgentSettings::override_global(settings, cx);
4934 });
4935
4936 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4937
4938 #[allow(clippy::arc_with_non_send_sync)]
4939 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4940 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4941
4942 let task = cx.update(|cx| {
4943 tool.run(
4944 crate::DeletePathToolInput {
4945 path: "root/important_data.txt".to_string(),
4946 },
4947 event_stream,
4948 cx,
4949 )
4950 });
4951
4952 let result = task.await;
4953 assert!(result.is_err(), "expected deletion to be blocked");
4954 assert!(
4955 result.unwrap_err().to_string().contains("blocked"),
4956 "error should mention the deletion was blocked"
4957 );
4958}
4959
4960#[gpui::test]
4961async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4962 init_test(cx);
4963
4964 let fs = FakeFs::new(cx.executor());
4965 fs.insert_tree(
4966 "/root",
4967 json!({
4968 "safe.txt": "content",
4969 "protected": {}
4970 }),
4971 )
4972 .await;
4973 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4974
4975 cx.update(|cx| {
4976 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4977 settings.tool_permissions.tools.insert(
4978 MovePathTool::NAME.into(),
4979 agent_settings::ToolRules {
4980 default: Some(settings::ToolPermissionMode::Allow),
4981 always_allow: vec![],
4982 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4983 always_confirm: vec![],
4984 invalid_patterns: vec![],
4985 },
4986 );
4987 agent_settings::AgentSettings::override_global(settings, cx);
4988 });
4989
4990 #[allow(clippy::arc_with_non_send_sync)]
4991 let tool = Arc::new(crate::MovePathTool::new(project));
4992 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4993
4994 let task = cx.update(|cx| {
4995 tool.run(
4996 crate::MovePathToolInput {
4997 source_path: "root/safe.txt".to_string(),
4998 destination_path: "root/protected/safe.txt".to_string(),
4999 },
5000 event_stream,
5001 cx,
5002 )
5003 });
5004
5005 let result = task.await;
5006 assert!(
5007 result.is_err(),
5008 "expected move to be blocked due to destination path"
5009 );
5010 assert!(
5011 result.unwrap_err().to_string().contains("blocked"),
5012 "error should mention the move was blocked"
5013 );
5014}
5015
5016#[gpui::test]
5017async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5018 init_test(cx);
5019
5020 let fs = FakeFs::new(cx.executor());
5021 fs.insert_tree(
5022 "/root",
5023 json!({
5024 "secret.txt": "secret content",
5025 "public": {}
5026 }),
5027 )
5028 .await;
5029 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5030
5031 cx.update(|cx| {
5032 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5033 settings.tool_permissions.tools.insert(
5034 MovePathTool::NAME.into(),
5035 agent_settings::ToolRules {
5036 default: Some(settings::ToolPermissionMode::Allow),
5037 always_allow: vec![],
5038 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5039 always_confirm: vec![],
5040 invalid_patterns: vec![],
5041 },
5042 );
5043 agent_settings::AgentSettings::override_global(settings, cx);
5044 });
5045
5046 #[allow(clippy::arc_with_non_send_sync)]
5047 let tool = Arc::new(crate::MovePathTool::new(project));
5048 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5049
5050 let task = cx.update(|cx| {
5051 tool.run(
5052 crate::MovePathToolInput {
5053 source_path: "root/secret.txt".to_string(),
5054 destination_path: "root/public/not_secret.txt".to_string(),
5055 },
5056 event_stream,
5057 cx,
5058 )
5059 });
5060
5061 let result = task.await;
5062 assert!(
5063 result.is_err(),
5064 "expected move to be blocked due to source path"
5065 );
5066 assert!(
5067 result.unwrap_err().to_string().contains("blocked"),
5068 "error should mention the move was blocked"
5069 );
5070}
5071
5072#[gpui::test]
5073async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5074 init_test(cx);
5075
5076 let fs = FakeFs::new(cx.executor());
5077 fs.insert_tree(
5078 "/root",
5079 json!({
5080 "confidential.txt": "confidential data",
5081 "dest": {}
5082 }),
5083 )
5084 .await;
5085 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5086
5087 cx.update(|cx| {
5088 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5089 settings.tool_permissions.tools.insert(
5090 CopyPathTool::NAME.into(),
5091 agent_settings::ToolRules {
5092 default: Some(settings::ToolPermissionMode::Allow),
5093 always_allow: vec![],
5094 always_deny: vec![
5095 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5096 ],
5097 always_confirm: vec![],
5098 invalid_patterns: vec![],
5099 },
5100 );
5101 agent_settings::AgentSettings::override_global(settings, cx);
5102 });
5103
5104 #[allow(clippy::arc_with_non_send_sync)]
5105 let tool = Arc::new(crate::CopyPathTool::new(project));
5106 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5107
5108 let task = cx.update(|cx| {
5109 tool.run(
5110 crate::CopyPathToolInput {
5111 source_path: "root/confidential.txt".to_string(),
5112 destination_path: "root/dest/copy.txt".to_string(),
5113 },
5114 event_stream,
5115 cx,
5116 )
5117 });
5118
5119 let result = task.await;
5120 assert!(result.is_err(), "expected copy to be blocked");
5121 assert!(
5122 result.unwrap_err().to_string().contains("blocked"),
5123 "error should mention the copy was blocked"
5124 );
5125}
5126
5127#[gpui::test]
5128async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5129 init_test(cx);
5130
5131 let fs = FakeFs::new(cx.executor());
5132 fs.insert_tree(
5133 "/root",
5134 json!({
5135 "normal.txt": "normal content",
5136 "readonly": {
5137 "config.txt": "readonly content"
5138 }
5139 }),
5140 )
5141 .await;
5142 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5143
5144 cx.update(|cx| {
5145 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5146 settings.tool_permissions.tools.insert(
5147 SaveFileTool::NAME.into(),
5148 agent_settings::ToolRules {
5149 default: Some(settings::ToolPermissionMode::Allow),
5150 always_allow: vec![],
5151 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5152 always_confirm: vec![],
5153 invalid_patterns: vec![],
5154 },
5155 );
5156 agent_settings::AgentSettings::override_global(settings, cx);
5157 });
5158
5159 #[allow(clippy::arc_with_non_send_sync)]
5160 let tool = Arc::new(crate::SaveFileTool::new(project));
5161 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5162
5163 let task = cx.update(|cx| {
5164 tool.run(
5165 crate::SaveFileToolInput {
5166 paths: vec![
5167 std::path::PathBuf::from("root/normal.txt"),
5168 std::path::PathBuf::from("root/readonly/config.txt"),
5169 ],
5170 },
5171 event_stream,
5172 cx,
5173 )
5174 });
5175
5176 let result = task.await;
5177 assert!(
5178 result.is_err(),
5179 "expected save to be blocked due to denied path"
5180 );
5181 assert!(
5182 result.unwrap_err().to_string().contains("blocked"),
5183 "error should mention the save was blocked"
5184 );
5185}
5186
5187#[gpui::test]
5188async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5189 init_test(cx);
5190
5191 let fs = FakeFs::new(cx.executor());
5192 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5193 .await;
5194 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5195
5196 cx.update(|cx| {
5197 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5198 settings.tool_permissions.tools.insert(
5199 SaveFileTool::NAME.into(),
5200 agent_settings::ToolRules {
5201 default: Some(settings::ToolPermissionMode::Allow),
5202 always_allow: vec![],
5203 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5204 always_confirm: vec![],
5205 invalid_patterns: vec![],
5206 },
5207 );
5208 agent_settings::AgentSettings::override_global(settings, cx);
5209 });
5210
5211 #[allow(clippy::arc_with_non_send_sync)]
5212 let tool = Arc::new(crate::SaveFileTool::new(project));
5213 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5214
5215 let task = cx.update(|cx| {
5216 tool.run(
5217 crate::SaveFileToolInput {
5218 paths: vec![std::path::PathBuf::from("root/config.secret")],
5219 },
5220 event_stream,
5221 cx,
5222 )
5223 });
5224
5225 let result = task.await;
5226 assert!(result.is_err(), "expected save to be blocked");
5227 assert!(
5228 result.unwrap_err().to_string().contains("blocked"),
5229 "error should mention the save was blocked"
5230 );
5231}
5232
5233#[gpui::test]
5234async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5235 init_test(cx);
5236
5237 cx.update(|cx| {
5238 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5239 settings.tool_permissions.tools.insert(
5240 WebSearchTool::NAME.into(),
5241 agent_settings::ToolRules {
5242 default: Some(settings::ToolPermissionMode::Allow),
5243 always_allow: vec![],
5244 always_deny: vec![
5245 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5246 ],
5247 always_confirm: vec![],
5248 invalid_patterns: vec![],
5249 },
5250 );
5251 agent_settings::AgentSettings::override_global(settings, cx);
5252 });
5253
5254 #[allow(clippy::arc_with_non_send_sync)]
5255 let tool = Arc::new(crate::WebSearchTool);
5256 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5257
5258 let input: crate::WebSearchToolInput =
5259 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5260
5261 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5262
5263 let result = task.await;
5264 assert!(result.is_err(), "expected search to be blocked");
5265 assert!(
5266 result.unwrap_err().to_string().contains("blocked"),
5267 "error should mention the search was blocked"
5268 );
5269}
5270
5271#[gpui::test]
5272async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5273 init_test(cx);
5274
5275 let fs = FakeFs::new(cx.executor());
5276 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5277 .await;
5278 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5279
5280 cx.update(|cx| {
5281 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5282 settings.tool_permissions.tools.insert(
5283 EditFileTool::NAME.into(),
5284 agent_settings::ToolRules {
5285 default: Some(settings::ToolPermissionMode::Confirm),
5286 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5287 always_deny: vec![],
5288 always_confirm: vec![],
5289 invalid_patterns: vec![],
5290 },
5291 );
5292 agent_settings::AgentSettings::override_global(settings, cx);
5293 });
5294
5295 let context_server_registry =
5296 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5297 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5298 let templates = crate::Templates::new();
5299 let thread = cx.new(|cx| {
5300 crate::Thread::new(
5301 project.clone(),
5302 cx.new(|_cx| prompt_store::ProjectContext::default()),
5303 context_server_registry,
5304 templates.clone(),
5305 None,
5306 cx,
5307 )
5308 });
5309
5310 #[allow(clippy::arc_with_non_send_sync)]
5311 let tool = Arc::new(crate::EditFileTool::new(
5312 project,
5313 thread.downgrade(),
5314 language_registry,
5315 templates,
5316 ));
5317 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5318
5319 let _task = cx.update(|cx| {
5320 tool.run(
5321 crate::EditFileToolInput {
5322 display_description: "Edit README".to_string(),
5323 path: "root/README.md".into(),
5324 mode: crate::EditFileMode::Edit,
5325 },
5326 event_stream,
5327 cx,
5328 )
5329 });
5330
5331 cx.run_until_parked();
5332
5333 let event = rx.try_next();
5334 assert!(
5335 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5336 "expected no authorization request for allowed .md file"
5337 );
5338}
5339
5340#[gpui::test]
5341async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5342 init_test(cx);
5343
5344 let fs = FakeFs::new(cx.executor());
5345 fs.insert_tree(
5346 "/root",
5347 json!({
5348 ".zed": {
5349 "settings.json": "{}"
5350 },
5351 "README.md": "# Hello"
5352 }),
5353 )
5354 .await;
5355 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5356
5357 cx.update(|cx| {
5358 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5359 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5360 agent_settings::AgentSettings::override_global(settings, cx);
5361 });
5362
5363 let context_server_registry =
5364 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5365 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5366 let templates = crate::Templates::new();
5367 let thread = cx.new(|cx| {
5368 crate::Thread::new(
5369 project.clone(),
5370 cx.new(|_cx| prompt_store::ProjectContext::default()),
5371 context_server_registry,
5372 templates.clone(),
5373 None,
5374 cx,
5375 )
5376 });
5377
5378 #[allow(clippy::arc_with_non_send_sync)]
5379 let tool = Arc::new(crate::EditFileTool::new(
5380 project,
5381 thread.downgrade(),
5382 language_registry,
5383 templates,
5384 ));
5385
5386 // Editing a file inside .zed/ should still prompt even with global default: allow,
5387 // because local settings paths are sensitive and require confirmation regardless.
5388 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5389 let _task = cx.update(|cx| {
5390 tool.run(
5391 crate::EditFileToolInput {
5392 display_description: "Edit local settings".to_string(),
5393 path: "root/.zed/settings.json".into(),
5394 mode: crate::EditFileMode::Edit,
5395 },
5396 event_stream,
5397 cx,
5398 )
5399 });
5400
5401 let _update = rx.expect_update_fields().await;
5402 let _auth = rx.expect_authorization().await;
5403}
5404
5405#[gpui::test]
5406async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5407 init_test(cx);
5408
5409 cx.update(|cx| {
5410 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5411 settings.tool_permissions.tools.insert(
5412 FetchTool::NAME.into(),
5413 agent_settings::ToolRules {
5414 default: Some(settings::ToolPermissionMode::Allow),
5415 always_allow: vec![],
5416 always_deny: vec![
5417 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5418 ],
5419 always_confirm: vec![],
5420 invalid_patterns: vec![],
5421 },
5422 );
5423 agent_settings::AgentSettings::override_global(settings, cx);
5424 });
5425
5426 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5427
5428 #[allow(clippy::arc_with_non_send_sync)]
5429 let tool = Arc::new(crate::FetchTool::new(http_client));
5430 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5431
5432 let input: crate::FetchToolInput =
5433 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5434
5435 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5436
5437 let result = task.await;
5438 assert!(result.is_err(), "expected fetch to be blocked");
5439 assert!(
5440 result.unwrap_err().to_string().contains("blocked"),
5441 "error should mention the fetch was blocked"
5442 );
5443}
5444
5445#[gpui::test]
5446async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5447 init_test(cx);
5448
5449 cx.update(|cx| {
5450 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5451 settings.tool_permissions.tools.insert(
5452 FetchTool::NAME.into(),
5453 agent_settings::ToolRules {
5454 default: Some(settings::ToolPermissionMode::Confirm),
5455 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5456 always_deny: vec![],
5457 always_confirm: vec![],
5458 invalid_patterns: vec![],
5459 },
5460 );
5461 agent_settings::AgentSettings::override_global(settings, cx);
5462 });
5463
5464 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5465
5466 #[allow(clippy::arc_with_non_send_sync)]
5467 let tool = Arc::new(crate::FetchTool::new(http_client));
5468 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5469
5470 let input: crate::FetchToolInput =
5471 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5472
5473 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5474
5475 cx.run_until_parked();
5476
5477 let event = rx.try_next();
5478 assert!(
5479 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5480 "expected no authorization request for allowed docs.rs URL"
5481 );
5482}
5483
5484#[gpui::test]
5485async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5486 init_test(cx);
5487 always_allow_tools(cx);
5488
5489 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5490 let fake_model = model.as_fake();
5491
5492 // Add a tool so we can simulate tool calls
5493 thread.update(cx, |thread, _cx| {
5494 thread.add_tool(EchoTool, None);
5495 });
5496
5497 // Start a turn by sending a message
5498 let mut events = thread
5499 .update(cx, |thread, cx| {
5500 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5501 })
5502 .unwrap();
5503 cx.run_until_parked();
5504
5505 // Simulate the model making a tool call
5506 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5507 LanguageModelToolUse {
5508 id: "tool_1".into(),
5509 name: "echo".into(),
5510 raw_input: r#"{"text": "hello"}"#.into(),
5511 input: json!({"text": "hello"}),
5512 is_input_complete: true,
5513 thought_signature: None,
5514 },
5515 ));
5516 fake_model
5517 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5518
5519 // Signal that a message is queued before ending the stream
5520 thread.update(cx, |thread, _cx| {
5521 thread.set_has_queued_message(true);
5522 });
5523
5524 // Now end the stream - tool will run, and the boundary check should see the queue
5525 fake_model.end_last_completion_stream();
5526
5527 // Collect all events until the turn stops
5528 let all_events = collect_events_until_stop(&mut events, cx).await;
5529
5530 // Verify we received the tool call event
5531 let tool_call_ids: Vec<_> = all_events
5532 .iter()
5533 .filter_map(|e| match e {
5534 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5535 _ => None,
5536 })
5537 .collect();
5538 assert_eq!(
5539 tool_call_ids,
5540 vec!["tool_1"],
5541 "Should have received a tool call event for our echo tool"
5542 );
5543
5544 // The turn should have stopped with EndTurn
5545 let stop_reasons = stop_events(all_events);
5546 assert_eq!(
5547 stop_reasons,
5548 vec![acp::StopReason::EndTurn],
5549 "Turn should have ended after tool completion due to queued message"
5550 );
5551
5552 // Verify the queued message flag is still set
5553 thread.update(cx, |thread, _cx| {
5554 assert!(
5555 thread.has_queued_message(),
5556 "Should still have queued message flag set"
5557 );
5558 });
5559
5560 // Thread should be idle now
5561 thread.update(cx, |thread, _cx| {
5562 assert!(
5563 thread.is_turn_complete(),
5564 "Thread should not be running after turn ends"
5565 );
5566 });
5567}