1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Test double for `crate::TerminalHandle` whose exit timing and reported
/// output are fixed by the constructor used.
struct FakeTerminalHandle {
    /// Set to `true` by `TerminalHandle::kill`; read back via `was_killed`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; written via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// Exit signal; `signal_exit` takes it out and fires it at most once.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared exit task handed out (cloned) by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Output returned (cloned) by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Terminal id returned (cloned) by `id`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
/// Test double for `SubagentHandle` with a fixed session id and a prebuilt
/// summary task.
struct FakeSubagentHandle {
    /// Id returned (cloned) by `SubagentHandle::id`.
    session_id: acp::SessionId,
    /// Shared task whose output `wait_for_summary` resolves to.
    wait_for_summary_task: Shared<Task<String>>,
}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
/// `crate::ThreadEnvironment` test double that hands out preconfigured fake
/// handles. Each `create_*` method panics if its handle was never configured.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Cloned and returned by every `create_terminal` call.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Cloned and returned by every `create_subagent` call.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
///
/// Each `create_terminal` call builds a fresh never-exiting `FakeTerminalHandle`
/// and records it, so tests can inspect every created terminal via `handles()`.
struct MultiTerminalEnvironment {
    /// Created handles, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
288#[gpui::test]
289async fn test_echo(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let events = thread
294 .update(cx, |thread, cx| {
295 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
296 })
297 .unwrap();
298 cx.run_until_parked();
299 fake_model.send_last_completion_stream_text_chunk("Hello");
300 fake_model
301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
302 fake_model.end_last_completion_stream();
303
304 let events = events.collect().await;
305 thread.update(cx, |thread, _cx| {
306 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
307 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
308 });
309 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
310}
311
/// Runs the terminal tool against a terminal that never exits, with a 5 ms
/// timeout, and checks that the timeout kills the terminal and that the tool's
/// result still includes the output captured so far ("partial output").
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so we can observe whether the tool killed it.
    let handle = environment.terminal_handle.clone().unwrap();

    // NOTE(review): presumably allowed because the tool holds the `Rc`-based
    // (non-Send) environment; confirm.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first update should attach the terminal to the tool call.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fused and pinned so it can be polled repeatedly without blocking.
    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task without blocking, driving the executor and its timers
    // between polls; fail after 500 ms of real (wall-clock) time.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
377
/// Runs the terminal tool with no timeout against a terminal that never exits
/// and checks that the terminal has not been killed after a short wait.
///
/// NOTE(review): this test is `#[ignore]`d and no reason is recorded here;
/// TODO document why, or re-enable it.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so we can observe whether the tool killed it.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Bound (not `_`) presumably so the tool task is not dropped/cancelled
    // while we wait below.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    // The first update should attach the terminal to the tool call.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Wait on a 25 ms background-executor timer before checking.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
427
428#[gpui::test]
429async fn test_thinking(cx: &mut TestAppContext) {
430 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
431 let fake_model = model.as_fake();
432
433 let events = thread
434 .update(cx, |thread, cx| {
435 thread.send(
436 UserMessageId::new(),
437 [indoc! {"
438 Testing:
439
440 Generate a thinking step where you just think the word 'Think',
441 and have your final answer be 'Hello'
442 "}],
443 cx,
444 )
445 })
446 .unwrap();
447 cx.run_until_parked();
448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
449 text: "Think".to_string(),
450 signature: None,
451 });
452 fake_model.send_last_completion_stream_text_chunk("Hello");
453 fake_model
454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
455 fake_model.end_last_completion_stream();
456
457 let events = events.collect().await;
458 thread.update(cx, |thread, _cx| {
459 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
460 assert_eq!(
461 thread.last_message().unwrap().to_markdown(),
462 indoc! {"
463 <think>Think</think>
464 Hello
465 "}
466 )
467 });
468 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
469}
470
471#[gpui::test]
472async fn test_system_prompt(cx: &mut TestAppContext) {
473 let ThreadTest {
474 model,
475 thread,
476 project_context,
477 ..
478 } = setup(cx, TestModel::Fake).await;
479 let fake_model = model.as_fake();
480
481 project_context.update(cx, |project_context, _cx| {
482 project_context.shell = "test-shell".into()
483 });
484 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["abc"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491 let mut pending_completions = fake_model.pending_completions();
492 assert_eq!(
493 pending_completions.len(),
494 1,
495 "unexpected pending completions: {:?}",
496 pending_completions
497 );
498
499 let pending_completion = pending_completions.pop().unwrap();
500 assert_eq!(pending_completion.messages[0].role, Role::System);
501
502 let system_message = &pending_completion.messages[0];
503 let system_prompt = system_message.content[0].to_str().unwrap();
504 assert!(
505 system_prompt.contains("test-shell"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509 assert!(
510 system_prompt.contains("## Fixing Diagnostics"),
511 "unexpected system message: {:?}",
512 system_message
513 );
514}
515
516#[gpui::test]
517async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
519 let fake_model = model.as_fake();
520
521 thread
522 .update(cx, |thread, cx| {
523 thread.send(UserMessageId::new(), ["abc"], cx)
524 })
525 .unwrap();
526 cx.run_until_parked();
527 let mut pending_completions = fake_model.pending_completions();
528 assert_eq!(
529 pending_completions.len(),
530 1,
531 "unexpected pending completions: {:?}",
532 pending_completions
533 );
534
535 let pending_completion = pending_completions.pop().unwrap();
536 assert_eq!(pending_completion.messages[0].role, Role::System);
537
538 let system_message = &pending_completion.messages[0];
539 let system_prompt = system_message.content[0].to_str().unwrap();
540 assert!(
541 !system_prompt.contains("## Tool Use"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545 assert!(
546 !system_prompt.contains("## Fixing Diagnostics"),
547 "unexpected system message: {:?}",
548 system_message
549 );
550}
551
/// Checks that each request marks only its final message with `cache: true`.
/// Comparisons skip index 0, which the system-prompt tests show is the system
/// message. Three cases are covered: a plain user message, a follow-up after
/// an assistant reply, and a tool result following an `EchoTool` call.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The tool ran, so the newest pending request carries its result as the
    // final (cached) user message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
697
/// End-to-end test against the real `Sonnet4` model (ignored unless the `e2e`
/// feature is enabled). It exercises a fast `EchoTool` call and then a slower
/// `DelayTool` call, whose output ("Ding") must be echoed in the final
/// assistant message.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool, None);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The final agent message should contain the delay tool's "Ding" output;
    // on failure, the whole thread is printed as markdown.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
757
/// End-to-end test against the real `Sonnet4` model (ignored unless the `e2e`
/// feature is enabled). It checks that `word_list` tool input is streamed
/// incrementally. At least one pending tool call must show incomplete input
/// that does not yet contain `"g"`, and every non-pending tool call must have
/// both `"a"` and `"g"`.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask for a word_list call and watch its input arrive across several
    // tool-call events.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool, None);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Still streaming: record if the input is visibly partial.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
809
/// Exercises the permission flow for `ToolRequiringPermission`. One call is
/// allowed once, and one is denied, which yields an error tool result. A third
/// call is answered with "always allow", after which a fourth call produces
/// its result without the test answering any authorization request.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission, None);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request should carry both results: success for the
    // approved call and an error for the denied one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No authorization was answered for tool_id_4, so its result being present
    // shows the earlier "always allow" took effect.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
948
949#[gpui::test]
950async fn test_tool_hallucination(cx: &mut TestAppContext) {
951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
952 let fake_model = model.as_fake();
953
954 let mut events = thread
955 .update(cx, |thread, cx| {
956 thread.send(UserMessageId::new(), ["abc"], cx)
957 })
958 .unwrap();
959 cx.run_until_parked();
960 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
961 LanguageModelToolUse {
962 id: "tool_id_1".into(),
963 name: "nonexistent_tool".into(),
964 raw_input: "{}".into(),
965 input: json!({}),
966 is_input_complete: true,
967 thought_signature: None,
968 },
969 ));
970 fake_model.end_last_completion_stream();
971
972 let tool_call = expect_tool_call(&mut events).await;
973 assert_eq!(tool_call.title, "nonexistent_tool");
974 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
975 let update = expect_tool_call_update_fields(&mut events).await;
976 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
977}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
/// A terminal command with a recognizable program yields three dropdown
/// choices: whole tool, that program's commands, and this time only.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in [
        "Always for terminal",
        "Always for `cargo` commands",
        "Only this time",
    ] {
        assert!(
            labels.contains(&expected),
            "missing {expected:?} in {labels:?}"
        );
    }
}
1059
/// Editing a file offers "always" choices for the whole tool and for the
/// file's parent directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(
            labels.contains(&expected),
            "missing {expected:?} in {labels:?}"
        );
    }
}
1077
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a docs.rs URL should offer a tool-wide choice plus one scoped to
    // the `docs.rs` domain.
    let context = ToolPermissionContext::new(
        FetchTool::NAME,
        vec!["https://docs.rs/gpui".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(labels.contains(&expected), "missing label {expected:?}");
    }
}
1095
#[test]
fn test_permission_options_without_pattern() {
    // Running `./deploy.sh` yields no command-pattern choice: only the tool-wide
    // and one-off choices are offered, and no label mentions "commands".
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(labels.contains(&expected), "missing label {expected:?}");
    }
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1117
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization yields a flat list of exactly two one-off
    // options: `allow` (AllowOnce) and `deny` (RejectOnce).
    let context = ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    );
    let PermissionOptions::Flat(options) = context.build_permission_options() else {
        panic!("Expected flat permission options for symlink target authorization");
    };

    assert_eq!(options.len(), 2);
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1138
#[test]
fn test_permission_option_ids_for_terminal() {
    // Each dropdown choice pairs an allow option with a deny option. Check the
    // ids on both sides: tool-wide, one-off, and pattern-scoped (ids starting
    // with `always_{allow,deny}_pattern:terminal\n`).
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();
    let contains = |ids: &[String], expected: &str| ids.iter().any(|id| id == expected);

    assert!(contains(&allow_ids, "always_allow:terminal"));
    assert!(contains(&allow_ids, "allow"));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(contains(&deny_ids, "always_deny:terminal"));
    assert!(contains(&deny_ids, "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1178
1179#[gpui::test]
1180#[cfg_attr(not(feature = "e2e"), ignore)]
1181async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1182 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1183
1184 // Test concurrent tool calls with different delay times
1185 let events = thread
1186 .update(cx, |thread, cx| {
1187 thread.add_tool(DelayTool, None);
1188 thread.send(
1189 UserMessageId::new(),
1190 [
1191 "Call the delay tool twice in the same message.",
1192 "Once with 100ms. Once with 300ms.",
1193 "When both timers are complete, describe the outputs.",
1194 ],
1195 cx,
1196 )
1197 })
1198 .unwrap()
1199 .collect()
1200 .await;
1201
1202 let stop_reasons = stop_events(events);
1203 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1204
1205 thread.update(cx, |thread, _cx| {
1206 let last_message = thread.last_message().unwrap();
1207 let agent_message = last_message.as_agent_message().unwrap();
1208 let text = agent_message
1209 .content
1210 .iter()
1211 .filter_map(|content| {
1212 if let AgentMessageContent::Text(text) = content {
1213 Some(text.as_str())
1214 } else {
1215 None
1216 }
1217 })
1218 .collect::<String>();
1219
1220 assert!(text.contains("Ding"));
1221 });
1222}
1223
1224#[gpui::test]
1225async fn test_profiles(cx: &mut TestAppContext) {
1226 let ThreadTest {
1227 model, thread, fs, ..
1228 } = setup(cx, TestModel::Fake).await;
1229 let fake_model = model.as_fake();
1230
1231 thread.update(cx, |thread, _cx| {
1232 thread.add_tool(DelayTool, None);
1233 thread.add_tool(EchoTool, None);
1234 thread.add_tool(InfiniteTool, None);
1235 });
1236
1237 // Override profiles and wait for settings to be loaded.
1238 fs.insert_file(
1239 paths::settings_file(),
1240 json!({
1241 "agent": {
1242 "profiles": {
1243 "test-1": {
1244 "name": "Test Profile 1",
1245 "tools": {
1246 EchoTool::NAME: true,
1247 DelayTool::NAME: true,
1248 }
1249 },
1250 "test-2": {
1251 "name": "Test Profile 2",
1252 "tools": {
1253 InfiniteTool::NAME: true,
1254 }
1255 }
1256 }
1257 }
1258 })
1259 .to_string()
1260 .into_bytes(),
1261 )
1262 .await;
1263 cx.run_until_parked();
1264
1265 // Test that test-1 profile (default) has echo and delay tools
1266 thread
1267 .update(cx, |thread, cx| {
1268 thread.set_profile(AgentProfileId("test-1".into()), cx);
1269 thread.send(UserMessageId::new(), ["test"], cx)
1270 })
1271 .unwrap();
1272 cx.run_until_parked();
1273
1274 let mut pending_completions = fake_model.pending_completions();
1275 assert_eq!(pending_completions.len(), 1);
1276 let completion = pending_completions.pop().unwrap();
1277 let tool_names: Vec<String> = completion
1278 .tools
1279 .iter()
1280 .map(|tool| tool.name.clone())
1281 .collect();
1282 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1283 fake_model.end_last_completion_stream();
1284
1285 // Switch to test-2 profile, and verify that it has only the infinite tool.
1286 thread
1287 .update(cx, |thread, cx| {
1288 thread.set_profile(AgentProfileId("test-2".into()), cx);
1289 thread.send(UserMessageId::new(), ["test2"], cx)
1290 })
1291 .unwrap();
1292 cx.run_until_parked();
1293 let mut pending_completions = fake_model.pending_completions();
1294 assert_eq!(pending_completions.len(), 1);
1295 let completion = pending_completions.pop().unwrap();
1296 let tool_names: Vec<String> = completion
1297 .tools
1298 .iter()
1299 .map(|tool| tool.name.clone())
1300 .collect();
1301 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1302}
1303
1304#[gpui::test]
1305async fn test_mcp_tools(cx: &mut TestAppContext) {
1306 let ThreadTest {
1307 model,
1308 thread,
1309 context_server_store,
1310 fs,
1311 ..
1312 } = setup(cx, TestModel::Fake).await;
1313 let fake_model = model.as_fake();
1314
1315 // Override profiles and wait for settings to be loaded.
1316 fs.insert_file(
1317 paths::settings_file(),
1318 json!({
1319 "agent": {
1320 "tool_permissions": { "default": "allow" },
1321 "profiles": {
1322 "test": {
1323 "name": "Test Profile",
1324 "enable_all_context_servers": true,
1325 "tools": {
1326 EchoTool::NAME: true,
1327 }
1328 },
1329 }
1330 }
1331 })
1332 .to_string()
1333 .into_bytes(),
1334 )
1335 .await;
1336 cx.run_until_parked();
1337 thread.update(cx, |thread, cx| {
1338 thread.set_profile(AgentProfileId("test".into()), cx)
1339 });
1340
1341 let mut mcp_tool_calls = setup_context_server(
1342 "test_server",
1343 vec![context_server::types::Tool {
1344 name: "echo".into(),
1345 description: None,
1346 input_schema: serde_json::to_value(EchoTool::input_schema(
1347 LanguageModelToolSchemaFormat::JsonSchema,
1348 ))
1349 .unwrap(),
1350 output_schema: None,
1351 annotations: None,
1352 }],
1353 &context_server_store,
1354 cx,
1355 );
1356
1357 let events = thread.update(cx, |thread, cx| {
1358 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1359 });
1360 cx.run_until_parked();
1361
1362 // Simulate the model calling the MCP tool.
1363 let completion = fake_model.pending_completions().pop().unwrap();
1364 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1365 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1366 LanguageModelToolUse {
1367 id: "tool_1".into(),
1368 name: "echo".into(),
1369 raw_input: json!({"text": "test"}).to_string(),
1370 input: json!({"text": "test"}),
1371 is_input_complete: true,
1372 thought_signature: None,
1373 },
1374 ));
1375 fake_model.end_last_completion_stream();
1376 cx.run_until_parked();
1377
1378 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1379 assert_eq!(tool_call_params.name, "echo");
1380 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1381 tool_call_response
1382 .send(context_server::types::CallToolResponse {
1383 content: vec![context_server::types::ToolResponseContent::Text {
1384 text: "test".into(),
1385 }],
1386 is_error: None,
1387 meta: None,
1388 structured_content: None,
1389 })
1390 .unwrap();
1391 cx.run_until_parked();
1392
1393 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1394 fake_model.send_last_completion_stream_text_chunk("Done!");
1395 fake_model.end_last_completion_stream();
1396 events.collect::<Vec<_>>().await;
1397
1398 // Send again after adding the echo tool, ensuring the name collision is resolved.
1399 let events = thread.update(cx, |thread, cx| {
1400 thread.add_tool(EchoTool, None);
1401 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1402 });
1403 cx.run_until_parked();
1404 let completion = fake_model.pending_completions().pop().unwrap();
1405 assert_eq!(
1406 tool_names_for_completion(&completion),
1407 vec!["echo", "test_server_echo"]
1408 );
1409 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1410 LanguageModelToolUse {
1411 id: "tool_2".into(),
1412 name: "test_server_echo".into(),
1413 raw_input: json!({"text": "mcp"}).to_string(),
1414 input: json!({"text": "mcp"}),
1415 is_input_complete: true,
1416 thought_signature: None,
1417 },
1418 ));
1419 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1420 LanguageModelToolUse {
1421 id: "tool_3".into(),
1422 name: "echo".into(),
1423 raw_input: json!({"text": "native"}).to_string(),
1424 input: json!({"text": "native"}),
1425 is_input_complete: true,
1426 thought_signature: None,
1427 },
1428 ));
1429 fake_model.end_last_completion_stream();
1430 cx.run_until_parked();
1431
1432 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1433 assert_eq!(tool_call_params.name, "echo");
1434 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1435 tool_call_response
1436 .send(context_server::types::CallToolResponse {
1437 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1438 is_error: None,
1439 meta: None,
1440 structured_content: None,
1441 })
1442 .unwrap();
1443 cx.run_until_parked();
1444
1445 // Ensure the tool results were inserted with the correct names.
1446 let completion = fake_model.pending_completions().pop().unwrap();
1447 assert_eq!(
1448 completion.messages.last().unwrap().content,
1449 vec![
1450 MessageContent::ToolResult(LanguageModelToolResult {
1451 tool_use_id: "tool_3".into(),
1452 tool_name: "echo".into(),
1453 is_error: false,
1454 content: "native".into(),
1455 output: Some("native".into()),
1456 },),
1457 MessageContent::ToolResult(LanguageModelToolResult {
1458 tool_use_id: "tool_2".into(),
1459 tool_name: "test_server_echo".into(),
1460 is_error: false,
1461 content: "mcp".into(),
1462 output: Some("mcp".into()),
1463 },),
1464 ]
1465 );
1466 fake_model.end_last_completion_stream();
1467 events.collect::<Vec<_>>().await;
1468}
1469
1470#[gpui::test]
1471async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1472 let ThreadTest {
1473 model,
1474 thread,
1475 context_server_store,
1476 fs,
1477 ..
1478 } = setup(cx, TestModel::Fake).await;
1479 let fake_model = model.as_fake();
1480
1481 // Setup settings to allow MCP tools
1482 fs.insert_file(
1483 paths::settings_file(),
1484 json!({
1485 "agent": {
1486 "always_allow_tool_actions": true,
1487 "profiles": {
1488 "test": {
1489 "name": "Test Profile",
1490 "enable_all_context_servers": true,
1491 "tools": {}
1492 },
1493 }
1494 }
1495 })
1496 .to_string()
1497 .into_bytes(),
1498 )
1499 .await;
1500 cx.run_until_parked();
1501 thread.update(cx, |thread, cx| {
1502 thread.set_profile(AgentProfileId("test".into()), cx)
1503 });
1504
1505 // Setup a context server with a tool
1506 let mut mcp_tool_calls = setup_context_server(
1507 "github_server",
1508 vec![context_server::types::Tool {
1509 name: "issue_read".into(),
1510 description: Some("Read a GitHub issue".into()),
1511 input_schema: json!({
1512 "type": "object",
1513 "properties": {
1514 "issue_url": { "type": "string" }
1515 }
1516 }),
1517 output_schema: None,
1518 annotations: None,
1519 }],
1520 &context_server_store,
1521 cx,
1522 );
1523
1524 // Send a message and have the model call the MCP tool
1525 let events = thread.update(cx, |thread, cx| {
1526 thread
1527 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1528 .unwrap()
1529 });
1530 cx.run_until_parked();
1531
1532 // Verify the MCP tool is available to the model
1533 let completion = fake_model.pending_completions().pop().unwrap();
1534 assert_eq!(
1535 tool_names_for_completion(&completion),
1536 vec!["issue_read"],
1537 "MCP tool should be available"
1538 );
1539
1540 // Simulate the model calling the MCP tool
1541 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1542 LanguageModelToolUse {
1543 id: "tool_1".into(),
1544 name: "issue_read".into(),
1545 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1546 .to_string(),
1547 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1548 is_input_complete: true,
1549 thought_signature: None,
1550 },
1551 ));
1552 fake_model.end_last_completion_stream();
1553 cx.run_until_parked();
1554
1555 // The MCP server receives the tool call and responds with content
1556 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1557 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1558 assert_eq!(tool_call_params.name, "issue_read");
1559 tool_call_response
1560 .send(context_server::types::CallToolResponse {
1561 content: vec![context_server::types::ToolResponseContent::Text {
1562 text: expected_tool_output.into(),
1563 }],
1564 is_error: None,
1565 meta: None,
1566 structured_content: None,
1567 })
1568 .unwrap();
1569 cx.run_until_parked();
1570
1571 // After tool completes, the model continues with a new completion request
1572 // that includes the tool results. We need to respond to this.
1573 let _completion = fake_model.pending_completions().pop().unwrap();
1574 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1575 fake_model
1576 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1577 fake_model.end_last_completion_stream();
1578 events.collect::<Vec<_>>().await;
1579
1580 // Verify the tool result is stored in the thread by checking the markdown output.
1581 // The tool result is in the first assistant message (not the last one, which is
1582 // the model's response after the tool completed).
1583 thread.update(cx, |thread, _cx| {
1584 let markdown = thread.to_markdown();
1585 assert!(
1586 markdown.contains("**Tool Result**: issue_read"),
1587 "Thread should contain tool result header"
1588 );
1589 assert!(
1590 markdown.contains(expected_tool_output),
1591 "Thread should contain tool output: {}",
1592 expected_tool_output
1593 );
1594 });
1595
1596 // Simulate app restart: disconnect the MCP server.
1597 // After restart, the MCP server won't be connected yet when the thread is replayed.
1598 context_server_store.update(cx, |store, cx| {
1599 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1600 });
1601 cx.run_until_parked();
1602
1603 // Replay the thread (this is what happens when loading a saved thread)
1604 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1605
1606 let mut found_tool_call = None;
1607 let mut found_tool_call_update_with_output = None;
1608
1609 while let Some(event) = replay_events.next().await {
1610 let event = event.unwrap();
1611 match &event {
1612 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1613 found_tool_call = Some(tc.clone());
1614 }
1615 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1616 if update.tool_call_id.to_string() == "tool_1" =>
1617 {
1618 if update.fields.raw_output.is_some() {
1619 found_tool_call_update_with_output = Some(update.clone());
1620 }
1621 }
1622 _ => {}
1623 }
1624 }
1625
1626 // The tool call should be found
1627 assert!(
1628 found_tool_call.is_some(),
1629 "Tool call should be emitted during replay"
1630 );
1631
1632 assert!(
1633 found_tool_call_update_with_output.is_some(),
1634 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1635 );
1636
1637 let update = found_tool_call_update_with_output.unwrap();
1638 assert_eq!(
1639 update.fields.raw_output,
1640 Some(expected_tool_output.into()),
1641 "raw_output should contain the saved tool result"
1642 );
1643
1644 // Also verify the status is correct (completed, not failed)
1645 assert_eq!(
1646 update.fields.status,
1647 Some(acp::ToolCallStatus::Completed),
1648 "Tool call status should reflect the original completion status"
1649 );
1650}
1651
1652#[gpui::test]
1653async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1654 let ThreadTest {
1655 model,
1656 thread,
1657 context_server_store,
1658 fs,
1659 ..
1660 } = setup(cx, TestModel::Fake).await;
1661 let fake_model = model.as_fake();
1662
1663 // Set up a profile with all tools enabled
1664 fs.insert_file(
1665 paths::settings_file(),
1666 json!({
1667 "agent": {
1668 "profiles": {
1669 "test": {
1670 "name": "Test Profile",
1671 "enable_all_context_servers": true,
1672 "tools": {
1673 EchoTool::NAME: true,
1674 DelayTool::NAME: true,
1675 WordListTool::NAME: true,
1676 ToolRequiringPermission::NAME: true,
1677 InfiniteTool::NAME: true,
1678 }
1679 },
1680 }
1681 }
1682 })
1683 .to_string()
1684 .into_bytes(),
1685 )
1686 .await;
1687 cx.run_until_parked();
1688
1689 thread.update(cx, |thread, cx| {
1690 thread.set_profile(AgentProfileId("test".into()), cx);
1691 thread.add_tool(EchoTool, None);
1692 thread.add_tool(DelayTool, None);
1693 thread.add_tool(WordListTool, None);
1694 thread.add_tool(ToolRequiringPermission, None);
1695 thread.add_tool(InfiniteTool, None);
1696 });
1697
1698 // Set up multiple context servers with some overlapping tool names
1699 let _server1_calls = setup_context_server(
1700 "xxx",
1701 vec![
1702 context_server::types::Tool {
1703 name: "echo".into(), // Conflicts with native EchoTool
1704 description: None,
1705 input_schema: serde_json::to_value(EchoTool::input_schema(
1706 LanguageModelToolSchemaFormat::JsonSchema,
1707 ))
1708 .unwrap(),
1709 output_schema: None,
1710 annotations: None,
1711 },
1712 context_server::types::Tool {
1713 name: "unique_tool_1".into(),
1714 description: None,
1715 input_schema: json!({"type": "object", "properties": {}}),
1716 output_schema: None,
1717 annotations: None,
1718 },
1719 ],
1720 &context_server_store,
1721 cx,
1722 );
1723
1724 let _server2_calls = setup_context_server(
1725 "yyy",
1726 vec![
1727 context_server::types::Tool {
1728 name: "echo".into(), // Also conflicts with native EchoTool
1729 description: None,
1730 input_schema: serde_json::to_value(EchoTool::input_schema(
1731 LanguageModelToolSchemaFormat::JsonSchema,
1732 ))
1733 .unwrap(),
1734 output_schema: None,
1735 annotations: None,
1736 },
1737 context_server::types::Tool {
1738 name: "unique_tool_2".into(),
1739 description: None,
1740 input_schema: json!({"type": "object", "properties": {}}),
1741 output_schema: None,
1742 annotations: None,
1743 },
1744 context_server::types::Tool {
1745 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1746 description: None,
1747 input_schema: json!({"type": "object", "properties": {}}),
1748 output_schema: None,
1749 annotations: None,
1750 },
1751 context_server::types::Tool {
1752 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1753 description: None,
1754 input_schema: json!({"type": "object", "properties": {}}),
1755 output_schema: None,
1756 annotations: None,
1757 },
1758 ],
1759 &context_server_store,
1760 cx,
1761 );
1762 let _server3_calls = setup_context_server(
1763 "zzz",
1764 vec![
1765 context_server::types::Tool {
1766 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1767 description: None,
1768 input_schema: json!({"type": "object", "properties": {}}),
1769 output_schema: None,
1770 annotations: None,
1771 },
1772 context_server::types::Tool {
1773 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1774 description: None,
1775 input_schema: json!({"type": "object", "properties": {}}),
1776 output_schema: None,
1777 annotations: None,
1778 },
1779 context_server::types::Tool {
1780 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1781 description: None,
1782 input_schema: json!({"type": "object", "properties": {}}),
1783 output_schema: None,
1784 annotations: None,
1785 },
1786 ],
1787 &context_server_store,
1788 cx,
1789 );
1790
1791 // Server with spaces in name - tests snake_case conversion for API compatibility
1792 let _server4_calls = setup_context_server(
1793 "Azure DevOps",
1794 vec![context_server::types::Tool {
1795 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1796 description: None,
1797 input_schema: serde_json::to_value(EchoTool::input_schema(
1798 LanguageModelToolSchemaFormat::JsonSchema,
1799 ))
1800 .unwrap(),
1801 output_schema: None,
1802 annotations: None,
1803 }],
1804 &context_server_store,
1805 cx,
1806 );
1807
1808 thread
1809 .update(cx, |thread, cx| {
1810 thread.send(UserMessageId::new(), ["Go"], cx)
1811 })
1812 .unwrap();
1813 cx.run_until_parked();
1814 let completion = fake_model.pending_completions().pop().unwrap();
1815 assert_eq!(
1816 tool_names_for_completion(&completion),
1817 vec![
1818 "azure_dev_ops_echo",
1819 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1820 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1821 "delay",
1822 "echo",
1823 "infinite",
1824 "tool_requiring_permission",
1825 "unique_tool_1",
1826 "unique_tool_2",
1827 "word_list",
1828 "xxx_echo",
1829 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1830 "yyy_echo",
1831 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1832 ]
1833 );
1834}
1835
1836#[gpui::test]
1837#[cfg_attr(not(feature = "e2e"), ignore)]
1838async fn test_cancellation(cx: &mut TestAppContext) {
1839 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1840
1841 let mut events = thread
1842 .update(cx, |thread, cx| {
1843 thread.add_tool(InfiniteTool, None);
1844 thread.add_tool(EchoTool, None);
1845 thread.send(
1846 UserMessageId::new(),
1847 ["Call the echo tool, then call the infinite tool, then explain their output"],
1848 cx,
1849 )
1850 })
1851 .unwrap();
1852
1853 // Wait until both tools are called.
1854 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1855 let mut echo_id = None;
1856 let mut echo_completed = false;
1857 while let Some(event) = events.next().await {
1858 match event.unwrap() {
1859 ThreadEvent::ToolCall(tool_call) => {
1860 assert_eq!(tool_call.title, expected_tools.remove(0));
1861 if tool_call.title == "Echo" {
1862 echo_id = Some(tool_call.tool_call_id);
1863 }
1864 }
1865 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1866 acp::ToolCallUpdate {
1867 tool_call_id,
1868 fields:
1869 acp::ToolCallUpdateFields {
1870 status: Some(acp::ToolCallStatus::Completed),
1871 ..
1872 },
1873 ..
1874 },
1875 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1876 echo_completed = true;
1877 }
1878 _ => {}
1879 }
1880
1881 if expected_tools.is_empty() && echo_completed {
1882 break;
1883 }
1884 }
1885
1886 // Cancel the current send and ensure that the event stream is closed, even
1887 // if one of the tools is still running.
1888 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1889 let events = events.collect::<Vec<_>>().await;
1890 let last_event = events.last();
1891 assert!(
1892 matches!(
1893 last_event,
1894 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1895 ),
1896 "unexpected event {last_event:?}"
1897 );
1898
1899 // Ensure we can still send a new message after cancellation.
1900 let events = thread
1901 .update(cx, |thread, cx| {
1902 thread.send(
1903 UserMessageId::new(),
1904 ["Testing: reply with 'Hello' then stop."],
1905 cx,
1906 )
1907 })
1908 .unwrap()
1909 .collect::<Vec<_>>()
1910 .await;
1911 thread.update(cx, |thread, _cx| {
1912 let message = thread.last_message().unwrap();
1913 let agent_message = message.as_agent_message().unwrap();
1914 assert_eq!(
1915 agent_message.content,
1916 vec![AgentMessageContent::Text("Hello".to_string())]
1917 );
1918 });
1919 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1920}
1921
1922#[gpui::test]
1923async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1924 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1925 always_allow_tools(cx);
1926 let fake_model = model.as_fake();
1927
1928 let environment = Rc::new(cx.update(|cx| {
1929 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1930 }));
1931 let handle = environment.terminal_handle.clone().unwrap();
1932
1933 let mut events = thread
1934 .update(cx, |thread, cx| {
1935 thread.add_tool(
1936 crate::TerminalTool::new(thread.project().clone(), environment),
1937 None,
1938 );
1939 thread.send(UserMessageId::new(), ["run a command"], cx)
1940 })
1941 .unwrap();
1942
1943 cx.run_until_parked();
1944
1945 // Simulate the model calling the terminal tool
1946 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1947 LanguageModelToolUse {
1948 id: "terminal_tool_1".into(),
1949 name: TerminalTool::NAME.into(),
1950 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1951 input: json!({"command": "sleep 1000", "cd": "."}),
1952 is_input_complete: true,
1953 thought_signature: None,
1954 },
1955 ));
1956 fake_model.end_last_completion_stream();
1957
1958 // Wait for the terminal tool to start running
1959 wait_for_terminal_tool_started(&mut events, cx).await;
1960
1961 // Cancel the thread while the terminal is running
1962 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1963
1964 // Collect remaining events, driving the executor to let cancellation complete
1965 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1966
1967 // Verify the terminal was killed
1968 assert!(
1969 handle.was_killed(),
1970 "expected terminal handle to be killed on cancellation"
1971 );
1972
1973 // Verify we got a cancellation stop event
1974 assert_eq!(
1975 stop_events(remaining_events),
1976 vec![acp::StopReason::Cancelled],
1977 );
1978
1979 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1980 thread.update(cx, |thread, _cx| {
1981 let message = thread.last_message().unwrap();
1982 let agent_message = message.as_agent_message().unwrap();
1983
1984 let tool_use = agent_message
1985 .content
1986 .iter()
1987 .find_map(|content| match content {
1988 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1989 _ => None,
1990 })
1991 .expect("expected tool use in agent message");
1992
1993 let tool_result = agent_message
1994 .tool_results
1995 .get(&tool_use.id)
1996 .expect("expected tool result");
1997
1998 let result_text = match &tool_result.content {
1999 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2000 _ => panic!("expected text content in tool result"),
2001 };
2002
2003 // "partial output" comes from FakeTerminalHandle's output field
2004 assert!(
2005 result_text.contains("partial output"),
2006 "expected tool result to contain terminal output, got: {result_text}"
2007 );
2008 // Match the actual format from process_content in terminal_tool.rs
2009 assert!(
2010 result_text.contains("The user stopped this command"),
2011 "expected tool result to indicate user stopped, got: {result_text}"
2012 );
2013 });
2014
2015 // Verify we can send a new message after cancellation
2016 verify_thread_recovery(&thread, &fake_model, cx).await;
2017}
2018
2019#[gpui::test]
2020async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2021 // This test verifies that tools which properly handle cancellation via
2022 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2023 // to cancellation and report that they were cancelled.
2024 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2025 always_allow_tools(cx);
2026 let fake_model = model.as_fake();
2027
2028 let (tool, was_cancelled) = CancellationAwareTool::new();
2029
2030 let mut events = thread
2031 .update(cx, |thread, cx| {
2032 thread.add_tool(tool, None);
2033 thread.send(
2034 UserMessageId::new(),
2035 ["call the cancellation aware tool"],
2036 cx,
2037 )
2038 })
2039 .unwrap();
2040
2041 cx.run_until_parked();
2042
2043 // Simulate the model calling the cancellation-aware tool
2044 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2045 LanguageModelToolUse {
2046 id: "cancellation_aware_1".into(),
2047 name: "cancellation_aware".into(),
2048 raw_input: r#"{}"#.into(),
2049 input: json!({}),
2050 is_input_complete: true,
2051 thought_signature: None,
2052 },
2053 ));
2054 fake_model.end_last_completion_stream();
2055
2056 cx.run_until_parked();
2057
2058 // Wait for the tool call to be reported
2059 let mut tool_started = false;
2060 let deadline = cx.executor().num_cpus() * 100;
2061 for _ in 0..deadline {
2062 cx.run_until_parked();
2063
2064 while let Some(Some(event)) = events.next().now_or_never() {
2065 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2066 if tool_call.title == "Cancellation Aware Tool" {
2067 tool_started = true;
2068 break;
2069 }
2070 }
2071 }
2072
2073 if tool_started {
2074 break;
2075 }
2076
2077 cx.background_executor
2078 .timer(Duration::from_millis(10))
2079 .await;
2080 }
2081 assert!(tool_started, "expected cancellation aware tool to start");
2082
2083 // Cancel the thread and wait for it to complete
2084 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2085
2086 // The cancel task should complete promptly because the tool handles cancellation
2087 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2088 futures::select! {
2089 _ = cancel_task.fuse() => {}
2090 _ = timeout.fuse() => {
2091 panic!("cancel task timed out - tool did not respond to cancellation");
2092 }
2093 }
2094
2095 // Verify the tool detected cancellation via its flag
2096 assert!(
2097 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2098 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2099 );
2100
2101 // Collect remaining events
2102 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2103
2104 // Verify we got a cancellation stop event
2105 assert_eq!(
2106 stop_events(remaining_events),
2107 vec![acp::StopReason::Cancelled],
2108 );
2109
2110 // Verify we can send a new message after cancellation
2111 verify_thread_recovery(&thread, &fake_model, cx).await;
2112}
2113
2114/// Helper to verify thread can recover after cancellation by sending a simple message.
2115async fn verify_thread_recovery(
2116 thread: &Entity<Thread>,
2117 fake_model: &FakeLanguageModel,
2118 cx: &mut TestAppContext,
2119) {
2120 let events = thread
2121 .update(cx, |thread, cx| {
2122 thread.send(
2123 UserMessageId::new(),
2124 ["Testing: reply with 'Hello' then stop."],
2125 cx,
2126 )
2127 })
2128 .unwrap();
2129 cx.run_until_parked();
2130 fake_model.send_last_completion_stream_text_chunk("Hello");
2131 fake_model
2132 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2133 fake_model.end_last_completion_stream();
2134
2135 let events = events.collect::<Vec<_>>().await;
2136 thread.update(cx, |thread, _cx| {
2137 let message = thread.last_message().unwrap();
2138 let agent_message = message.as_agent_message().unwrap();
2139 assert_eq!(
2140 agent_message.content,
2141 vec![AgentMessageContent::Text("Hello".to_string())]
2142 );
2143 });
2144 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2145}
2146
2147/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2148async fn wait_for_terminal_tool_started(
2149 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2150 cx: &mut TestAppContext,
2151) {
2152 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2153 for _ in 0..deadline {
2154 cx.run_until_parked();
2155
2156 while let Some(Some(event)) = events.next().now_or_never() {
2157 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2158 update,
2159 ))) = &event
2160 {
2161 if update.fields.content.as_ref().is_some_and(|content| {
2162 content
2163 .iter()
2164 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2165 }) {
2166 return;
2167 }
2168 }
2169 }
2170
2171 cx.background_executor
2172 .timer(Duration::from_millis(10))
2173 .await;
2174 }
2175 panic!("terminal tool did not start within the expected time");
2176}
2177
2178/// Collects events until a Stop event is received, driving the executor to completion.
2179async fn collect_events_until_stop(
2180 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2181 cx: &mut TestAppContext,
2182) -> Vec<Result<ThreadEvent>> {
2183 let mut collected = Vec::new();
2184 let deadline = cx.executor().num_cpus() * 200;
2185
2186 for _ in 0..deadline {
2187 cx.executor().advance_clock(Duration::from_millis(10));
2188 cx.run_until_parked();
2189
2190 while let Some(Some(event)) = events.next().now_or_never() {
2191 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2192 collected.push(event);
2193 if is_stop {
2194 return collected;
2195 }
2196 }
2197 }
2198 panic!(
2199 "did not receive Stop event within the expected time; collected {} events",
2200 collected.len()
2201 );
2202}
2203
2204#[gpui::test]
2205async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2206 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2207 always_allow_tools(cx);
2208 let fake_model = model.as_fake();
2209
2210 let environment = Rc::new(cx.update(|cx| {
2211 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2212 }));
2213 let handle = environment.terminal_handle.clone().unwrap();
2214
2215 let message_id = UserMessageId::new();
2216 let mut events = thread
2217 .update(cx, |thread, cx| {
2218 thread.add_tool(
2219 crate::TerminalTool::new(thread.project().clone(), environment),
2220 None,
2221 );
2222 thread.send(message_id.clone(), ["run a command"], cx)
2223 })
2224 .unwrap();
2225
2226 cx.run_until_parked();
2227
2228 // Simulate the model calling the terminal tool
2229 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2230 LanguageModelToolUse {
2231 id: "terminal_tool_1".into(),
2232 name: TerminalTool::NAME.into(),
2233 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2234 input: json!({"command": "sleep 1000", "cd": "."}),
2235 is_input_complete: true,
2236 thought_signature: None,
2237 },
2238 ));
2239 fake_model.end_last_completion_stream();
2240
2241 // Wait for the terminal tool to start running
2242 wait_for_terminal_tool_started(&mut events, cx).await;
2243
2244 // Truncate the thread while the terminal is running
2245 thread
2246 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2247 .unwrap();
2248
2249 // Drive the executor to let cancellation complete
2250 let _ = collect_events_until_stop(&mut events, cx).await;
2251
2252 // Verify the terminal was killed
2253 assert!(
2254 handle.was_killed(),
2255 "expected terminal handle to be killed on truncate"
2256 );
2257
2258 // Verify the thread is empty after truncation
2259 thread.update(cx, |thread, _cx| {
2260 assert_eq!(
2261 thread.to_markdown(),
2262 "",
2263 "expected thread to be empty after truncating the only message"
2264 );
2265 });
2266
2267 // Verify we can send a new message after truncation
2268 verify_thread_recovery(&thread, &fake_model, cx).await;
2269}
2270
2271#[gpui::test]
2272async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2273 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2274 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2275 always_allow_tools(cx);
2276 let fake_model = model.as_fake();
2277
2278 let environment = Rc::new(MultiTerminalEnvironment::new());
2279
2280 let mut events = thread
2281 .update(cx, |thread, cx| {
2282 thread.add_tool(
2283 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2284 None,
2285 );
2286 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2287 })
2288 .unwrap();
2289
2290 cx.run_until_parked();
2291
2292 // Simulate the model calling two terminal tools
2293 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2294 LanguageModelToolUse {
2295 id: "terminal_tool_1".into(),
2296 name: TerminalTool::NAME.into(),
2297 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2298 input: json!({"command": "sleep 1000", "cd": "."}),
2299 is_input_complete: true,
2300 thought_signature: None,
2301 },
2302 ));
2303 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2304 LanguageModelToolUse {
2305 id: "terminal_tool_2".into(),
2306 name: TerminalTool::NAME.into(),
2307 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2308 input: json!({"command": "sleep 2000", "cd": "."}),
2309 is_input_complete: true,
2310 thought_signature: None,
2311 },
2312 ));
2313 fake_model.end_last_completion_stream();
2314
2315 // Wait for both terminal tools to start by counting terminal content updates
2316 let mut terminals_started = 0;
2317 let deadline = cx.executor().num_cpus() * 100;
2318 for _ in 0..deadline {
2319 cx.run_until_parked();
2320
2321 while let Some(Some(event)) = events.next().now_or_never() {
2322 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2323 update,
2324 ))) = &event
2325 {
2326 if update.fields.content.as_ref().is_some_and(|content| {
2327 content
2328 .iter()
2329 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2330 }) {
2331 terminals_started += 1;
2332 if terminals_started >= 2 {
2333 break;
2334 }
2335 }
2336 }
2337 }
2338 if terminals_started >= 2 {
2339 break;
2340 }
2341
2342 cx.background_executor
2343 .timer(Duration::from_millis(10))
2344 .await;
2345 }
2346 assert!(
2347 terminals_started >= 2,
2348 "expected 2 terminal tools to start, got {terminals_started}"
2349 );
2350
2351 // Cancel the thread while both terminals are running
2352 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2353
2354 // Collect remaining events
2355 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2356
2357 // Verify both terminal handles were killed
2358 let handles = environment.handles();
2359 assert_eq!(
2360 handles.len(),
2361 2,
2362 "expected 2 terminal handles to be created"
2363 );
2364 assert!(
2365 handles[0].was_killed(),
2366 "expected first terminal handle to be killed on cancellation"
2367 );
2368 assert!(
2369 handles[1].was_killed(),
2370 "expected second terminal handle to be killed on cancellation"
2371 );
2372
2373 // Verify we got a cancellation stop event
2374 assert_eq!(
2375 stop_events(remaining_events),
2376 vec![acp::StopReason::Cancelled],
2377 );
2378}
2379
2380#[gpui::test]
2381async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2382 // Tests that clicking the stop button on the terminal card (as opposed to the main
2383 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2384 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2385 always_allow_tools(cx);
2386 let fake_model = model.as_fake();
2387
2388 let environment = Rc::new(cx.update(|cx| {
2389 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2390 }));
2391 let handle = environment.terminal_handle.clone().unwrap();
2392
2393 let mut events = thread
2394 .update(cx, |thread, cx| {
2395 thread.add_tool(
2396 crate::TerminalTool::new(thread.project().clone(), environment),
2397 None,
2398 );
2399 thread.send(UserMessageId::new(), ["run a command"], cx)
2400 })
2401 .unwrap();
2402
2403 cx.run_until_parked();
2404
2405 // Simulate the model calling the terminal tool
2406 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2407 LanguageModelToolUse {
2408 id: "terminal_tool_1".into(),
2409 name: TerminalTool::NAME.into(),
2410 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2411 input: json!({"command": "sleep 1000", "cd": "."}),
2412 is_input_complete: true,
2413 thought_signature: None,
2414 },
2415 ));
2416 fake_model.end_last_completion_stream();
2417
2418 // Wait for the terminal tool to start running
2419 wait_for_terminal_tool_started(&mut events, cx).await;
2420
2421 // Simulate user clicking stop on the terminal card itself.
2422 // This sets the flag and signals exit (simulating what the real UI would do).
2423 handle.set_stopped_by_user(true);
2424 handle.killed.store(true, Ordering::SeqCst);
2425 handle.signal_exit();
2426
2427 // Wait for the tool to complete
2428 cx.run_until_parked();
2429
2430 // The thread continues after tool completion - simulate the model ending its turn
2431 fake_model
2432 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2433 fake_model.end_last_completion_stream();
2434
2435 // Collect remaining events
2436 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2437
2438 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2439 assert_eq!(
2440 stop_events(remaining_events),
2441 vec![acp::StopReason::EndTurn],
2442 );
2443
2444 // Verify the tool result indicates user stopped
2445 thread.update(cx, |thread, _cx| {
2446 let message = thread.last_message().unwrap();
2447 let agent_message = message.as_agent_message().unwrap();
2448
2449 let tool_use = agent_message
2450 .content
2451 .iter()
2452 .find_map(|content| match content {
2453 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2454 _ => None,
2455 })
2456 .expect("expected tool use in agent message");
2457
2458 let tool_result = agent_message
2459 .tool_results
2460 .get(&tool_use.id)
2461 .expect("expected tool result");
2462
2463 let result_text = match &tool_result.content {
2464 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2465 _ => panic!("expected text content in tool result"),
2466 };
2467
2468 assert!(
2469 result_text.contains("The user stopped this command"),
2470 "expected tool result to indicate user stopped, got: {result_text}"
2471 );
2472 });
2473}
2474
2475#[gpui::test]
2476async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2477 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2478 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2479 always_allow_tools(cx);
2480 let fake_model = model.as_fake();
2481
2482 let environment = Rc::new(cx.update(|cx| {
2483 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2484 }));
2485 let handle = environment.terminal_handle.clone().unwrap();
2486
2487 let mut events = thread
2488 .update(cx, |thread, cx| {
2489 thread.add_tool(
2490 crate::TerminalTool::new(thread.project().clone(), environment),
2491 None,
2492 );
2493 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2494 })
2495 .unwrap();
2496
2497 cx.run_until_parked();
2498
2499 // Simulate the model calling the terminal tool with a short timeout
2500 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2501 LanguageModelToolUse {
2502 id: "terminal_tool_1".into(),
2503 name: TerminalTool::NAME.into(),
2504 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2505 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2506 is_input_complete: true,
2507 thought_signature: None,
2508 },
2509 ));
2510 fake_model.end_last_completion_stream();
2511
2512 // Wait for the terminal tool to start running
2513 wait_for_terminal_tool_started(&mut events, cx).await;
2514
2515 // Advance clock past the timeout
2516 cx.executor().advance_clock(Duration::from_millis(200));
2517 cx.run_until_parked();
2518
2519 // The thread continues after tool completion - simulate the model ending its turn
2520 fake_model
2521 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2522 fake_model.end_last_completion_stream();
2523
2524 // Collect remaining events
2525 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2526
2527 // Verify the terminal was killed due to timeout
2528 assert!(
2529 handle.was_killed(),
2530 "expected terminal handle to be killed on timeout"
2531 );
2532
2533 // Verify we got an EndTurn (the tool completed, just with timeout)
2534 assert_eq!(
2535 stop_events(remaining_events),
2536 vec![acp::StopReason::EndTurn],
2537 );
2538
2539 // Verify the tool result indicates timeout, not user stopped
2540 thread.update(cx, |thread, _cx| {
2541 let message = thread.last_message().unwrap();
2542 let agent_message = message.as_agent_message().unwrap();
2543
2544 let tool_use = agent_message
2545 .content
2546 .iter()
2547 .find_map(|content| match content {
2548 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2549 _ => None,
2550 })
2551 .expect("expected tool use in agent message");
2552
2553 let tool_result = agent_message
2554 .tool_results
2555 .get(&tool_use.id)
2556 .expect("expected tool result");
2557
2558 let result_text = match &tool_result.content {
2559 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2560 _ => panic!("expected text content in tool result"),
2561 };
2562
2563 assert!(
2564 result_text.contains("timed out"),
2565 "expected tool result to indicate timeout, got: {result_text}"
2566 );
2567 assert!(
2568 !result_text.contains("The user stopped"),
2569 "tool result should not mention user stopped when it timed out, got: {result_text}"
2570 );
2571 });
2572}
2573
2574#[gpui::test]
2575async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2576 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2577 let fake_model = model.as_fake();
2578
2579 let events_1 = thread
2580 .update(cx, |thread, cx| {
2581 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2582 })
2583 .unwrap();
2584 cx.run_until_parked();
2585 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2586 cx.run_until_parked();
2587
2588 let events_2 = thread
2589 .update(cx, |thread, cx| {
2590 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2591 })
2592 .unwrap();
2593 cx.run_until_parked();
2594 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2595 fake_model
2596 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2597 fake_model.end_last_completion_stream();
2598
2599 let events_1 = events_1.collect::<Vec<_>>().await;
2600 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2601 let events_2 = events_2.collect::<Vec<_>>().await;
2602 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2603}
2604
2605#[gpui::test]
2606async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2607 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2608 let fake_model = model.as_fake();
2609
2610 let events_1 = thread
2611 .update(cx, |thread, cx| {
2612 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2613 })
2614 .unwrap();
2615 cx.run_until_parked();
2616 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2617 fake_model
2618 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2619 fake_model.end_last_completion_stream();
2620 let events_1 = events_1.collect::<Vec<_>>().await;
2621
2622 let events_2 = thread
2623 .update(cx, |thread, cx| {
2624 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2625 })
2626 .unwrap();
2627 cx.run_until_parked();
2628 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2629 fake_model
2630 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2631 fake_model.end_last_completion_stream();
2632 let events_2 = events_2.collect::<Vec<_>>().await;
2633
2634 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2635 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2636}
2637
2638#[gpui::test]
2639async fn test_refusal(cx: &mut TestAppContext) {
2640 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2641 let fake_model = model.as_fake();
2642
2643 let events = thread
2644 .update(cx, |thread, cx| {
2645 thread.send(UserMessageId::new(), ["Hello"], cx)
2646 })
2647 .unwrap();
2648 cx.run_until_parked();
2649 thread.read_with(cx, |thread, _| {
2650 assert_eq!(
2651 thread.to_markdown(),
2652 indoc! {"
2653 ## User
2654
2655 Hello
2656 "}
2657 );
2658 });
2659
2660 fake_model.send_last_completion_stream_text_chunk("Hey!");
2661 cx.run_until_parked();
2662 thread.read_with(cx, |thread, _| {
2663 assert_eq!(
2664 thread.to_markdown(),
2665 indoc! {"
2666 ## User
2667
2668 Hello
2669
2670 ## Assistant
2671
2672 Hey!
2673 "}
2674 );
2675 });
2676
2677 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2678 fake_model
2679 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2680 let events = events.collect::<Vec<_>>().await;
2681 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2682 thread.read_with(cx, |thread, _| {
2683 assert_eq!(thread.to_markdown(), "");
2684 });
2685}
2686
2687#[gpui::test]
2688async fn test_truncate_first_message(cx: &mut TestAppContext) {
2689 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2690 let fake_model = model.as_fake();
2691
2692 let message_id = UserMessageId::new();
2693 thread
2694 .update(cx, |thread, cx| {
2695 thread.send(message_id.clone(), ["Hello"], cx)
2696 })
2697 .unwrap();
2698 cx.run_until_parked();
2699 thread.read_with(cx, |thread, _| {
2700 assert_eq!(
2701 thread.to_markdown(),
2702 indoc! {"
2703 ## User
2704
2705 Hello
2706 "}
2707 );
2708 assert_eq!(thread.latest_token_usage(), None);
2709 });
2710
2711 fake_model.send_last_completion_stream_text_chunk("Hey!");
2712 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2713 language_model::TokenUsage {
2714 input_tokens: 32_000,
2715 output_tokens: 16_000,
2716 cache_creation_input_tokens: 0,
2717 cache_read_input_tokens: 0,
2718 },
2719 ));
2720 cx.run_until_parked();
2721 thread.read_with(cx, |thread, _| {
2722 assert_eq!(
2723 thread.to_markdown(),
2724 indoc! {"
2725 ## User
2726
2727 Hello
2728
2729 ## Assistant
2730
2731 Hey!
2732 "}
2733 );
2734 assert_eq!(
2735 thread.latest_token_usage(),
2736 Some(acp_thread::TokenUsage {
2737 used_tokens: 32_000 + 16_000,
2738 max_tokens: 1_000_000,
2739 input_tokens: 32_000,
2740 output_tokens: 16_000,
2741 })
2742 );
2743 });
2744
2745 thread
2746 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2747 .unwrap();
2748 cx.run_until_parked();
2749 thread.read_with(cx, |thread, _| {
2750 assert_eq!(thread.to_markdown(), "");
2751 assert_eq!(thread.latest_token_usage(), None);
2752 });
2753
2754 // Ensure we can still send a new message after truncation.
2755 thread
2756 .update(cx, |thread, cx| {
2757 thread.send(UserMessageId::new(), ["Hi"], cx)
2758 })
2759 .unwrap();
2760 thread.update(cx, |thread, _cx| {
2761 assert_eq!(
2762 thread.to_markdown(),
2763 indoc! {"
2764 ## User
2765
2766 Hi
2767 "}
2768 );
2769 });
2770 cx.run_until_parked();
2771 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2772 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2773 language_model::TokenUsage {
2774 input_tokens: 40_000,
2775 output_tokens: 20_000,
2776 cache_creation_input_tokens: 0,
2777 cache_read_input_tokens: 0,
2778 },
2779 ));
2780 cx.run_until_parked();
2781 thread.read_with(cx, |thread, _| {
2782 assert_eq!(
2783 thread.to_markdown(),
2784 indoc! {"
2785 ## User
2786
2787 Hi
2788
2789 ## Assistant
2790
2791 Ahoy!
2792 "}
2793 );
2794
2795 assert_eq!(
2796 thread.latest_token_usage(),
2797 Some(acp_thread::TokenUsage {
2798 used_tokens: 40_000 + 20_000,
2799 max_tokens: 1_000_000,
2800 input_tokens: 40_000,
2801 output_tokens: 20_000,
2802 })
2803 );
2804 });
2805}
2806
2807#[gpui::test]
2808async fn test_truncate_second_message(cx: &mut TestAppContext) {
2809 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2810 let fake_model = model.as_fake();
2811
2812 thread
2813 .update(cx, |thread, cx| {
2814 thread.send(UserMessageId::new(), ["Message 1"], cx)
2815 })
2816 .unwrap();
2817 cx.run_until_parked();
2818 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2819 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2820 language_model::TokenUsage {
2821 input_tokens: 32_000,
2822 output_tokens: 16_000,
2823 cache_creation_input_tokens: 0,
2824 cache_read_input_tokens: 0,
2825 },
2826 ));
2827 fake_model.end_last_completion_stream();
2828 cx.run_until_parked();
2829
2830 let assert_first_message_state = |cx: &mut TestAppContext| {
2831 thread.clone().read_with(cx, |thread, _| {
2832 assert_eq!(
2833 thread.to_markdown(),
2834 indoc! {"
2835 ## User
2836
2837 Message 1
2838
2839 ## Assistant
2840
2841 Message 1 response
2842 "}
2843 );
2844
2845 assert_eq!(
2846 thread.latest_token_usage(),
2847 Some(acp_thread::TokenUsage {
2848 used_tokens: 32_000 + 16_000,
2849 max_tokens: 1_000_000,
2850 input_tokens: 32_000,
2851 output_tokens: 16_000,
2852 })
2853 );
2854 });
2855 };
2856
2857 assert_first_message_state(cx);
2858
2859 let second_message_id = UserMessageId::new();
2860 thread
2861 .update(cx, |thread, cx| {
2862 thread.send(second_message_id.clone(), ["Message 2"], cx)
2863 })
2864 .unwrap();
2865 cx.run_until_parked();
2866
2867 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2868 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2869 language_model::TokenUsage {
2870 input_tokens: 40_000,
2871 output_tokens: 20_000,
2872 cache_creation_input_tokens: 0,
2873 cache_read_input_tokens: 0,
2874 },
2875 ));
2876 fake_model.end_last_completion_stream();
2877 cx.run_until_parked();
2878
2879 thread.read_with(cx, |thread, _| {
2880 assert_eq!(
2881 thread.to_markdown(),
2882 indoc! {"
2883 ## User
2884
2885 Message 1
2886
2887 ## Assistant
2888
2889 Message 1 response
2890
2891 ## User
2892
2893 Message 2
2894
2895 ## Assistant
2896
2897 Message 2 response
2898 "}
2899 );
2900
2901 assert_eq!(
2902 thread.latest_token_usage(),
2903 Some(acp_thread::TokenUsage {
2904 used_tokens: 40_000 + 20_000,
2905 max_tokens: 1_000_000,
2906 input_tokens: 40_000,
2907 output_tokens: 20_000,
2908 })
2909 );
2910 });
2911
2912 thread
2913 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2914 .unwrap();
2915 cx.run_until_parked();
2916
2917 assert_first_message_state(cx);
2918}
2919
2920#[gpui::test]
2921async fn test_title_generation(cx: &mut TestAppContext) {
2922 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2923 let fake_model = model.as_fake();
2924
2925 let summary_model = Arc::new(FakeLanguageModel::default());
2926 thread.update(cx, |thread, cx| {
2927 thread.set_summarization_model(Some(summary_model.clone()), cx)
2928 });
2929
2930 let send = thread
2931 .update(cx, |thread, cx| {
2932 thread.send(UserMessageId::new(), ["Hello"], cx)
2933 })
2934 .unwrap();
2935 cx.run_until_parked();
2936
2937 fake_model.send_last_completion_stream_text_chunk("Hey!");
2938 fake_model.end_last_completion_stream();
2939 cx.run_until_parked();
2940 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2941
2942 // Ensure the summary model has been invoked to generate a title.
2943 summary_model.send_last_completion_stream_text_chunk("Hello ");
2944 summary_model.send_last_completion_stream_text_chunk("world\nG");
2945 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2946 summary_model.end_last_completion_stream();
2947 send.collect::<Vec<_>>().await;
2948 cx.run_until_parked();
2949 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2950
2951 // Send another message, ensuring no title is generated this time.
2952 let send = thread
2953 .update(cx, |thread, cx| {
2954 thread.send(UserMessageId::new(), ["Hello again"], cx)
2955 })
2956 .unwrap();
2957 cx.run_until_parked();
2958 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2959 fake_model.end_last_completion_stream();
2960 cx.run_until_parked();
2961 assert_eq!(summary_model.pending_completions(), Vec::new());
2962 send.collect::<Vec<_>>().await;
2963 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2964}
2965
2966#[gpui::test]
2967async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2968 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2969 let fake_model = model.as_fake();
2970
2971 let _events = thread
2972 .update(cx, |thread, cx| {
2973 thread.add_tool(ToolRequiringPermission, None);
2974 thread.add_tool(EchoTool, None);
2975 thread.send(UserMessageId::new(), ["Hey!"], cx)
2976 })
2977 .unwrap();
2978 cx.run_until_parked();
2979
2980 let permission_tool_use = LanguageModelToolUse {
2981 id: "tool_id_1".into(),
2982 name: ToolRequiringPermission::NAME.into(),
2983 raw_input: "{}".into(),
2984 input: json!({}),
2985 is_input_complete: true,
2986 thought_signature: None,
2987 };
2988 let echo_tool_use = LanguageModelToolUse {
2989 id: "tool_id_2".into(),
2990 name: EchoTool::NAME.into(),
2991 raw_input: json!({"text": "test"}).to_string(),
2992 input: json!({"text": "test"}),
2993 is_input_complete: true,
2994 thought_signature: None,
2995 };
2996 fake_model.send_last_completion_stream_text_chunk("Hi!");
2997 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2998 permission_tool_use,
2999 ));
3000 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3001 echo_tool_use.clone(),
3002 ));
3003 fake_model.end_last_completion_stream();
3004 cx.run_until_parked();
3005
3006 // Ensure pending tools are skipped when building a request.
3007 let request = thread
3008 .read_with(cx, |thread, cx| {
3009 thread.build_completion_request(CompletionIntent::EditFile, cx)
3010 })
3011 .unwrap();
3012 assert_eq!(
3013 request.messages[1..],
3014 vec![
3015 LanguageModelRequestMessage {
3016 role: Role::User,
3017 content: vec!["Hey!".into()],
3018 cache: true,
3019 reasoning_details: None,
3020 },
3021 LanguageModelRequestMessage {
3022 role: Role::Assistant,
3023 content: vec![
3024 MessageContent::Text("Hi!".into()),
3025 MessageContent::ToolUse(echo_tool_use.clone())
3026 ],
3027 cache: false,
3028 reasoning_details: None,
3029 },
3030 LanguageModelRequestMessage {
3031 role: Role::User,
3032 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3033 tool_use_id: echo_tool_use.id.clone(),
3034 tool_name: echo_tool_use.name,
3035 is_error: false,
3036 content: "test".into(),
3037 output: Some("test".into())
3038 })],
3039 cache: false,
3040 reasoning_details: None,
3041 },
3042 ],
3043 );
3044}
3045
3046#[gpui::test]
3047async fn test_agent_connection(cx: &mut TestAppContext) {
3048 cx.update(settings::init);
3049 let templates = Templates::new();
3050
3051 // Initialize language model system with test provider
3052 cx.update(|cx| {
3053 gpui_tokio::init(cx);
3054
3055 let http_client = FakeHttpClient::with_404_response();
3056 let clock = Arc::new(clock::FakeSystemClock::new());
3057 let client = Client::new(clock, http_client, cx);
3058 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3059 language_model::init(client.clone(), cx);
3060 language_models::init(user_store, client.clone(), cx);
3061 LanguageModelRegistry::test(cx);
3062 });
3063 cx.executor().forbid_parking();
3064
3065 // Create a project for new_thread
3066 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3067 fake_fs.insert_tree(path!("/test"), json!({})).await;
3068 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3069 let cwd = Path::new("/test");
3070 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3071
3072 // Create agent and connection
3073 let agent = NativeAgent::new(
3074 project.clone(),
3075 thread_store,
3076 templates.clone(),
3077 None,
3078 fake_fs.clone(),
3079 &mut cx.to_async(),
3080 )
3081 .await
3082 .unwrap();
3083 let connection = NativeAgentConnection(agent.clone());
3084
3085 // Create a thread using new_thread
3086 let connection_rc = Rc::new(connection.clone());
3087 let acp_thread = cx
3088 .update(|cx| connection_rc.new_session(project, cwd, cx))
3089 .await
3090 .expect("new_thread should succeed");
3091
3092 // Get the session_id from the AcpThread
3093 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3094
3095 // Test model_selector returns Some
3096 let selector_opt = connection.model_selector(&session_id);
3097 assert!(
3098 selector_opt.is_some(),
3099 "agent should always support ModelSelector"
3100 );
3101 let selector = selector_opt.unwrap();
3102
3103 // Test list_models
3104 let listed_models = cx
3105 .update(|cx| selector.list_models(cx))
3106 .await
3107 .expect("list_models should succeed");
3108 let AgentModelList::Grouped(listed_models) = listed_models else {
3109 panic!("Unexpected model list type");
3110 };
3111 assert!(!listed_models.is_empty(), "should have at least one model");
3112 assert_eq!(
3113 listed_models[&AgentModelGroupName("Fake".into())][0]
3114 .id
3115 .0
3116 .as_ref(),
3117 "fake/fake"
3118 );
3119
3120 // Test selected_model returns the default
3121 let model = cx
3122 .update(|cx| selector.selected_model(cx))
3123 .await
3124 .expect("selected_model should succeed");
3125 let model = cx
3126 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3127 .unwrap();
3128 let model = model.as_fake();
3129 assert_eq!(model.id().0, "fake", "should return default model");
3130
3131 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3132 cx.run_until_parked();
3133 model.send_last_completion_stream_text_chunk("def");
3134 cx.run_until_parked();
3135 acp_thread.read_with(cx, |thread, cx| {
3136 assert_eq!(
3137 thread.to_markdown(cx),
3138 indoc! {"
3139 ## User
3140
3141 abc
3142
3143 ## Assistant
3144
3145 def
3146
3147 "}
3148 )
3149 });
3150
3151 // Test cancel
3152 cx.update(|cx| connection.cancel(&session_id, cx));
3153 request.await.expect("prompt should fail gracefully");
3154
3155 // Explicitly close the session and drop the ACP thread.
3156 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3157 .await
3158 .unwrap();
3159 drop(acp_thread);
3160 let result = cx
3161 .update(|cx| {
3162 connection.prompt(
3163 Some(acp_thread::UserMessageId::new()),
3164 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3165 cx,
3166 )
3167 })
3168 .await;
3169 assert_eq!(
3170 result.as_ref().unwrap_err().to_string(),
3171 "Session not found",
3172 "unexpected result: {:?}",
3173 result
3174 );
3175}
3176
3177#[gpui::test]
3178async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3179 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3180 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3181 let fake_model = model.as_fake();
3182
3183 let mut events = thread
3184 .update(cx, |thread, cx| {
3185 thread.send(UserMessageId::new(), ["Echo something"], cx)
3186 })
3187 .unwrap();
3188 cx.run_until_parked();
3189
3190 // Simulate streaming partial input.
3191 let input = json!({});
3192 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3193 LanguageModelToolUse {
3194 id: "1".into(),
3195 name: EchoTool::NAME.into(),
3196 raw_input: input.to_string(),
3197 input,
3198 is_input_complete: false,
3199 thought_signature: None,
3200 },
3201 ));
3202
3203 // Input streaming completed
3204 let input = json!({ "text": "Hello!" });
3205 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3206 LanguageModelToolUse {
3207 id: "1".into(),
3208 name: "echo".into(),
3209 raw_input: input.to_string(),
3210 input,
3211 is_input_complete: true,
3212 thought_signature: None,
3213 },
3214 ));
3215 fake_model.end_last_completion_stream();
3216 cx.run_until_parked();
3217
3218 let tool_call = expect_tool_call(&mut events).await;
3219 assert_eq!(
3220 tool_call,
3221 acp::ToolCall::new("1", "Echo")
3222 .raw_input(json!({}))
3223 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3224 );
3225 let update = expect_tool_call_update_fields(&mut events).await;
3226 assert_eq!(
3227 update,
3228 acp::ToolCallUpdate::new(
3229 "1",
3230 acp::ToolCallUpdateFields::new()
3231 .title("Echo")
3232 .kind(acp::ToolKind::Other)
3233 .raw_input(json!({ "text": "Hello!"}))
3234 )
3235 );
3236 let update = expect_tool_call_update_fields(&mut events).await;
3237 assert_eq!(
3238 update,
3239 acp::ToolCallUpdate::new(
3240 "1",
3241 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3242 )
3243 );
3244 let update = expect_tool_call_update_fields(&mut events).await;
3245 assert_eq!(
3246 update,
3247 acp::ToolCallUpdate::new(
3248 "1",
3249 acp::ToolCallUpdateFields::new()
3250 .status(acp::ToolCallStatus::Completed)
3251 .raw_output("Hello!")
3252 )
3253 );
3254}
3255
/// A completion that streams text and then ends cleanly must not trigger any
/// retry. The event stream reaches `Stop` without a `Retry` event, and the
/// transcript contains only the user message and the assistant reply.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Answer the pending completion successfully.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the first `Stop`, collecting any retry notifications.
    // The loop also exits early on the first `Err` or when the stream ends.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
3298
/// A `ServerOverloaded` error with `retry_after: 3s`, raised mid-stream, should
/// produce exactly one retry (attempt 1) once the clock advances by that delay.
/// The partial text from the failed attempt stays in the transcript, and the
/// retried response follows a `[resume]` marker.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt: some text, then a retryable provider error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the requested `retry_after` so the retry request is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Second attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Collect retry notifications up to the first `Stop` (or first `Err`).
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3362
/// When a stream fails with a retryable error *after* emitting a complete tool
/// use, the tool still runs. The retried request must carry the assistant's
/// tool use followed by a user message holding the tool's result. After the
/// retried response ("Done"), the thread's last message is just that text.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // A complete tool use, immediately followed by a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    // Inspect the retried request. Message 0 is skipped (presumably the system
    // prompt — NOTE(review): confirm).
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried turn and drain the event stream to completion.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3442
3443#[gpui::test]
3444async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3445 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3446 let fake_model = model.as_fake();
3447
3448 let mut events = thread
3449 .update(cx, |thread, cx| {
3450 thread.send(UserMessageId::new(), ["Hello!"], cx)
3451 })
3452 .unwrap();
3453 cx.run_until_parked();
3454
3455 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3456 fake_model.send_last_completion_stream_error(
3457 LanguageModelCompletionError::ServerOverloaded {
3458 provider: LanguageModelProviderName::new("Anthropic"),
3459 retry_after: Some(Duration::from_secs(3)),
3460 },
3461 );
3462 fake_model.end_last_completion_stream();
3463 cx.executor().advance_clock(Duration::from_secs(3));
3464 cx.run_until_parked();
3465 }
3466
3467 let mut errors = Vec::new();
3468 let mut retry_events = Vec::new();
3469 while let Some(event) = events.next().await {
3470 match event {
3471 Ok(ThreadEvent::Retry(retry_status)) => {
3472 retry_events.push(retry_status);
3473 }
3474 Ok(ThreadEvent::Stop(..)) => break,
3475 Err(error) => errors.push(error),
3476 _ => {}
3477 }
3478 }
3479
3480 assert_eq!(
3481 retry_events.len(),
3482 crate::thread::MAX_RETRY_ATTEMPTS as usize
3483 );
3484 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3485 assert_eq!(retry_events[i].attempt, i + 1);
3486 }
3487 assert_eq!(errors.len(), 1);
3488 let error = errors[0]
3489 .downcast_ref::<LanguageModelCompletionError>()
3490 .unwrap();
3491 assert!(matches!(
3492 error,
3493 LanguageModelCompletionError::ServerOverloaded { .. }
3494 ));
3495}
3496
3497/// Filters out the stop events for asserting against in tests
3498fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3499 result_events
3500 .into_iter()
3501 .filter_map(|event| match event.unwrap() {
3502 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3503 _ => None,
3504 })
3505 .collect()
3506}
3507
/// Fixture returned by [`setup`]: a thread backed by a fake filesystem and a
/// test project, plus the handles tests use to drive or inspect it.
struct ThreadTest {
    /// The model the thread was created with (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context that was passed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    /// The project's context server store (used e.g. with `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem holding the user settings file and the `/test` worktree.
    fs: Arc<FakeFs>,
}
3515
3516enum TestModel {
3517 Sonnet4,
3518 Fake,
3519}
3520
3521impl TestModel {
3522 fn id(&self) -> LanguageModelId {
3523 match self {
3524 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3525 TestModel::Fake => unreachable!(),
3526 }
3527 }
3528}
3529
/// Builds a [`ThreadTest`] fixture for the requested model.
///
/// The steps are:
/// - write a user settings file to a fresh `FakeFs`, whose default agent
///   profile (`test-profile`) enables the test tools;
/// - keep the global settings in sync with that file via [`watch_settings`];
/// - create a project rooted at `/test`;
/// - construct a `Thread` using the chosen model.
///
/// For `TestModel::Sonnet4` this also installs a real HTTP client, initializes
/// the language model providers, and awaits provider authentication (unwrapping
/// the result). That path presumably needs network access and valid credentials.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably allows the executor to park while waiting on work
    // that completes outside the test scheduler (e.g. real network I/O) — confirm.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // User settings: make `test-profile` the default and enable every test tool in it.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Only the real model needs an HTTP client and the provider registry.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake is built directly; the real one is looked up
    // in the registry and its provider is authenticated before use.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3632
/// Installs `env_logger` when the test binary is loaded (via `ctor`).
///
/// Logging is enabled only when `RUST_LOG` is set to a valid Unicode value,
/// so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // `std::env::var` errors when the variable is unset or not valid Unicode.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3640
3641fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3642 let fs = fs.clone();
3643 cx.spawn({
3644 async move |cx| {
3645 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3646 cx.background_executor(),
3647 fs,
3648 paths::settings_file().clone(),
3649 );
3650 let _watcher_task = watcher_task;
3651
3652 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3653 cx.update(|cx| {
3654 SettingsStore::update_global(cx, |settings, cx| {
3655 settings.set_user_settings(&new_settings_content, cx)
3656 })
3657 })
3658 .ok();
3659 }
3660 }
3661 })
3662 .detach();
3663}
3664
3665fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3666 completion
3667 .tools
3668 .iter()
3669 .map(|tool| tool.name.clone())
3670 .collect()
3671}
3672
/// Registers a stdio context server named `name` in the project settings and
/// starts it on a fake transport that advertises `tools`.
///
/// Returns a receiver that yields each incoming `CallTool` request together
/// with a oneshot sender the test uses to reply. The fake server awaits that
/// reply and unwraps it, so dropping the sender without answering panics
/// inside the fake server's handler.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add an enabled server entry to the global project settings. The command
    // is a placeholder; the server is started below with a fake transport
    // (presumably the command is never executed — NOTE(review): confirm).
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: advertise tool support with list-change notifications.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Always report the full `tools` list in a single page.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Forward each tool call to the test and wait for its reply.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3752
/// `tokens_before_message(id)` should report the input-token usage of the
/// request that preceded the given user message. It is `None` for the first
/// message, and earlier messages keep their values as the thread grows.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3859
/// After truncating the thread at message 2, `tokens_before_message` should
/// return `None` for the removed message. The surviving first message still
/// reports `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    // (NOTE(review): only two messages are actually sent below.)
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3930
3931#[gpui::test]
3932async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3933 init_test(cx);
3934
3935 let fs = FakeFs::new(cx.executor());
3936 fs.insert_tree("/root", json!({})).await;
3937 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3938
3939 // Test 1: Deny rule blocks command
3940 {
3941 let environment = Rc::new(cx.update(|cx| {
3942 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3943 }));
3944
3945 cx.update(|cx| {
3946 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3947 settings.tool_permissions.tools.insert(
3948 TerminalTool::NAME.into(),
3949 agent_settings::ToolRules {
3950 default: Some(settings::ToolPermissionMode::Confirm),
3951 always_allow: vec![],
3952 always_deny: vec![
3953 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3954 ],
3955 always_confirm: vec![],
3956 invalid_patterns: vec![],
3957 },
3958 );
3959 agent_settings::AgentSettings::override_global(settings, cx);
3960 });
3961
3962 #[allow(clippy::arc_with_non_send_sync)]
3963 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3964 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3965
3966 let task = cx.update(|cx| {
3967 tool.run(
3968 crate::TerminalToolInput {
3969 command: "rm -rf /".to_string(),
3970 cd: ".".to_string(),
3971 timeout_ms: None,
3972 },
3973 event_stream,
3974 cx,
3975 )
3976 });
3977
3978 let result = task.await;
3979 assert!(
3980 result.is_err(),
3981 "expected command to be blocked by deny rule"
3982 );
3983 let err_msg = result.unwrap_err().to_string().to_lowercase();
3984 assert!(
3985 err_msg.contains("blocked"),
3986 "error should mention the command was blocked"
3987 );
3988 }
3989
3990 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
3991 {
3992 let environment = Rc::new(cx.update(|cx| {
3993 FakeThreadEnvironment::default()
3994 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
3995 }));
3996
3997 cx.update(|cx| {
3998 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3999 settings.tool_permissions.tools.insert(
4000 TerminalTool::NAME.into(),
4001 agent_settings::ToolRules {
4002 default: Some(settings::ToolPermissionMode::Deny),
4003 always_allow: vec![
4004 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4005 ],
4006 always_deny: vec![],
4007 always_confirm: vec![],
4008 invalid_patterns: vec![],
4009 },
4010 );
4011 agent_settings::AgentSettings::override_global(settings, cx);
4012 });
4013
4014 #[allow(clippy::arc_with_non_send_sync)]
4015 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4016 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4017
4018 let task = cx.update(|cx| {
4019 tool.run(
4020 crate::TerminalToolInput {
4021 command: "echo hello".to_string(),
4022 cd: ".".to_string(),
4023 timeout_ms: None,
4024 },
4025 event_stream,
4026 cx,
4027 )
4028 });
4029
4030 let update = rx.expect_update_fields().await;
4031 assert!(
4032 update.content.iter().any(|blocks| {
4033 blocks
4034 .iter()
4035 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4036 }),
4037 "expected terminal content (allow rule should skip confirmation and override default deny)"
4038 );
4039
4040 let result = task.await;
4041 assert!(
4042 result.is_ok(),
4043 "expected command to succeed without confirmation"
4044 );
4045 }
4046
4047 // Test 3: global default: allow does NOT override always_confirm patterns
4048 {
4049 let environment = Rc::new(cx.update(|cx| {
4050 FakeThreadEnvironment::default()
4051 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4052 }));
4053
4054 cx.update(|cx| {
4055 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4056 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4057 settings.tool_permissions.tools.insert(
4058 TerminalTool::NAME.into(),
4059 agent_settings::ToolRules {
4060 default: Some(settings::ToolPermissionMode::Allow),
4061 always_allow: vec![],
4062 always_deny: vec![],
4063 always_confirm: vec![
4064 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4065 ],
4066 invalid_patterns: vec![],
4067 },
4068 );
4069 agent_settings::AgentSettings::override_global(settings, cx);
4070 });
4071
4072 #[allow(clippy::arc_with_non_send_sync)]
4073 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4074 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4075
4076 let _task = cx.update(|cx| {
4077 tool.run(
4078 crate::TerminalToolInput {
4079 command: "sudo rm file".to_string(),
4080 cd: ".".to_string(),
4081 timeout_ms: None,
4082 },
4083 event_stream,
4084 cx,
4085 )
4086 });
4087
4088 // With global default: allow, confirm patterns are still respected
4089 // The expect_authorization() call will panic if no authorization is requested,
4090 // which validates that the confirm pattern still triggers confirmation
4091 let _auth = rx.expect_authorization().await;
4092
4093 drop(_task);
4094 }
4095
4096 // Test 4: tool-specific default: deny is respected even with global default: allow
4097 {
4098 let environment = Rc::new(cx.update(|cx| {
4099 FakeThreadEnvironment::default()
4100 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4101 }));
4102
4103 cx.update(|cx| {
4104 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4105 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4106 settings.tool_permissions.tools.insert(
4107 TerminalTool::NAME.into(),
4108 agent_settings::ToolRules {
4109 default: Some(settings::ToolPermissionMode::Deny),
4110 always_allow: vec![],
4111 always_deny: vec![],
4112 always_confirm: vec![],
4113 invalid_patterns: vec![],
4114 },
4115 );
4116 agent_settings::AgentSettings::override_global(settings, cx);
4117 });
4118
4119 #[allow(clippy::arc_with_non_send_sync)]
4120 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4121 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4122
4123 let task = cx.update(|cx| {
4124 tool.run(
4125 crate::TerminalToolInput {
4126 command: "echo hello".to_string(),
4127 cd: ".".to_string(),
4128 timeout_ms: None,
4129 },
4130 event_stream,
4131 cx,
4132 )
4133 });
4134
4135 // tool-specific default: deny is respected even with global default: allow
4136 let result = task.await;
4137 assert!(
4138 result.is_err(),
4139 "expected command to be blocked by tool-specific deny default"
4140 );
4141 let err_msg = result.unwrap_err().to_string().to_lowercase();
4142 assert!(
4143 err_msg.contains("disabled"),
4144 "error should mention the tool is disabled, got: {err_msg}"
4145 );
4146 }
4147}
4148
4149#[gpui::test]
4150async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4151 init_test(cx);
4152
4153 cx.update(|cx| {
4154 cx.update_flags(true, vec!["subagents".to_string()]);
4155 });
4156
4157 let fs = FakeFs::new(cx.executor());
4158 fs.insert_tree(path!("/test"), json!({})).await;
4159 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4160 let project_context = cx.new(|_cx| ProjectContext::default());
4161 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4162 let context_server_registry =
4163 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4164 let model = Arc::new(FakeLanguageModel::default());
4165
4166 let environment = Rc::new(cx.update(|cx| {
4167 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4168 }));
4169
4170 let thread = cx.new(|cx| {
4171 let mut thread = Thread::new(
4172 project.clone(),
4173 project_context,
4174 context_server_registry,
4175 Templates::new(),
4176 Some(model),
4177 cx,
4178 );
4179 thread.add_default_tools(None, environment, cx);
4180 thread
4181 });
4182
4183 thread.read_with(cx, |thread, _| {
4184 assert!(
4185 thread.has_registered_tool(SubagentTool::NAME),
4186 "subagent tool should be present when feature flag is enabled"
4187 );
4188 });
4189}
4190
4191#[gpui::test]
4192async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4193 init_test(cx);
4194
4195 cx.update(|cx| {
4196 cx.update_flags(true, vec!["subagents".to_string()]);
4197 });
4198
4199 let fs = FakeFs::new(cx.executor());
4200 fs.insert_tree(path!("/test"), json!({})).await;
4201 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4202 let project_context = cx.new(|_cx| ProjectContext::default());
4203 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4204 let context_server_registry =
4205 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4206 let model = Arc::new(FakeLanguageModel::default());
4207
4208 let parent_thread = cx.new(|cx| {
4209 Thread::new(
4210 project.clone(),
4211 project_context,
4212 context_server_registry,
4213 Templates::new(),
4214 Some(model.clone()),
4215 cx,
4216 )
4217 });
4218
4219 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4220 subagent_thread.read_with(cx, |subagent_thread, cx| {
4221 assert!(subagent_thread.is_subagent());
4222 assert_eq!(subagent_thread.depth(), 1);
4223 assert_eq!(
4224 subagent_thread.model().map(|model| model.id()),
4225 Some(model.id())
4226 );
4227 assert_eq!(
4228 subagent_thread.parent_thread_id(),
4229 Some(parent_thread.read(cx).id().clone())
4230 );
4231 });
4232}
4233
4234#[gpui::test]
4235async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4236 init_test(cx);
4237
4238 cx.update(|cx| {
4239 cx.update_flags(true, vec!["subagents".to_string()]);
4240 });
4241
4242 let fs = FakeFs::new(cx.executor());
4243 fs.insert_tree(path!("/test"), json!({})).await;
4244 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4245 let project_context = cx.new(|_cx| ProjectContext::default());
4246 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4247 let context_server_registry =
4248 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4249 let model = Arc::new(FakeLanguageModel::default());
4250 let environment = Rc::new(cx.update(|cx| {
4251 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4252 }));
4253
4254 let deep_parent_thread = cx.new(|cx| {
4255 let mut thread = Thread::new(
4256 project.clone(),
4257 project_context,
4258 context_server_registry,
4259 Templates::new(),
4260 Some(model.clone()),
4261 cx,
4262 );
4263 thread.set_subagent_context(SubagentContext {
4264 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4265 depth: MAX_SUBAGENT_DEPTH - 1,
4266 });
4267 thread
4268 });
4269 let deep_subagent_thread = cx.new(|cx| {
4270 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4271 thread.add_default_tools(None, environment, cx);
4272 thread
4273 });
4274
4275 deep_subagent_thread.read_with(cx, |thread, _| {
4276 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4277 assert!(
4278 !thread.has_registered_tool(SubagentTool::NAME),
4279 "subagent tool should not be present at max depth"
4280 );
4281 });
4282}
4283
4284#[gpui::test]
4285async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4286 init_test(cx);
4287
4288 cx.update(|cx| {
4289 cx.update_flags(true, vec!["subagents".to_string()]);
4290 });
4291
4292 let fs = FakeFs::new(cx.executor());
4293 fs.insert_tree(path!("/test"), json!({})).await;
4294 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4295 let project_context = cx.new(|_cx| ProjectContext::default());
4296 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4297 let context_server_registry =
4298 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4299 let model = Arc::new(FakeLanguageModel::default());
4300
4301 let parent = cx.new(|cx| {
4302 Thread::new(
4303 project.clone(),
4304 project_context.clone(),
4305 context_server_registry.clone(),
4306 Templates::new(),
4307 Some(model.clone()),
4308 cx,
4309 )
4310 });
4311
4312 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4313
4314 parent.update(cx, |thread, _cx| {
4315 thread.register_running_subagent(subagent.downgrade());
4316 });
4317
4318 subagent
4319 .update(cx, |thread, cx| {
4320 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4321 })
4322 .unwrap();
4323 cx.run_until_parked();
4324
4325 subagent.read_with(cx, |thread, _| {
4326 assert!(!thread.is_turn_complete(), "subagent should be running");
4327 });
4328
4329 parent.update(cx, |thread, cx| {
4330 thread.cancel(cx).detach();
4331 });
4332
4333 subagent.read_with(cx, |thread, _| {
4334 assert!(
4335 thread.is_turn_complete(),
4336 "subagent should be cancelled when parent cancels"
4337 );
4338 });
4339}
4340
4341#[gpui::test]
4342async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4343 // This test verifies that the subagent tool properly handles user cancellation
4344 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4345 init_test(cx);
4346 always_allow_tools(cx);
4347
4348 cx.update(|cx| {
4349 cx.update_flags(true, vec!["subagents".to_string()]);
4350 });
4351
4352 let fs = FakeFs::new(cx.executor());
4353 fs.insert_tree(path!("/test"), json!({})).await;
4354 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4355 let project_context = cx.new(|_cx| ProjectContext::default());
4356 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4357 let context_server_registry =
4358 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4359 let model = Arc::new(FakeLanguageModel::default());
4360 let environment = Rc::new(cx.update(|cx| {
4361 FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
4362 }));
4363
4364 let parent = cx.new(|cx| {
4365 Thread::new(
4366 project.clone(),
4367 project_context.clone(),
4368 context_server_registry.clone(),
4369 Templates::new(),
4370 Some(model.clone()),
4371 cx,
4372 )
4373 });
4374
4375 #[allow(clippy::arc_with_non_send_sync)]
4376 let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));
4377
4378 let (event_stream, _rx, mut cancellation_tx) =
4379 crate::ToolCallEventStream::test_with_cancellation();
4380
4381 // Start the subagent tool
4382 let task = cx.update(|cx| {
4383 tool.run(
4384 SubagentToolInput {
4385 label: "Long running task".to_string(),
4386 task_prompt: "Do a very long task that takes forever".to_string(),
4387 summary_prompt: "Summarize".to_string(),
4388 timeout_ms: None,
4389 allowed_tools: None,
4390 },
4391 event_stream.clone(),
4392 cx,
4393 )
4394 });
4395
4396 cx.run_until_parked();
4397
4398 // Signal cancellation via the event stream
4399 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4400
4401 // The task should complete promptly with a cancellation error
4402 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4403 let result = futures::select! {
4404 result = task.fuse() => result,
4405 _ = timeout.fuse() => {
4406 panic!("subagent tool did not respond to cancellation within timeout");
4407 }
4408 };
4409
4410 // Verify we got a cancellation error
4411 let err = result.unwrap_err();
4412 assert!(
4413 err.to_string().contains("cancelled by user"),
4414 "expected cancellation error, got: {}",
4415 err
4416 );
4417}
4418
4419#[gpui::test]
4420async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4421 init_test(cx);
4422 always_allow_tools(cx);
4423
4424 cx.update(|cx| {
4425 cx.update_flags(true, vec!["subagents".to_string()]);
4426 });
4427
4428 let fs = FakeFs::new(cx.executor());
4429 fs.insert_tree(path!("/test"), json!({})).await;
4430 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4431 let project_context = cx.new(|_cx| ProjectContext::default());
4432 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4433 let context_server_registry =
4434 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4435 cx.update(LanguageModelRegistry::test);
4436 let model = Arc::new(FakeLanguageModel::default());
4437 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4438 let native_agent = NativeAgent::new(
4439 project.clone(),
4440 thread_store,
4441 Templates::new(),
4442 None,
4443 fs,
4444 &mut cx.to_async(),
4445 )
4446 .await
4447 .unwrap();
4448 let parent_thread = cx.new(|cx| {
4449 Thread::new(
4450 project.clone(),
4451 project_context,
4452 context_server_registry,
4453 Templates::new(),
4454 Some(model.clone()),
4455 cx,
4456 )
4457 });
4458
4459 let mut handles = Vec::new();
4460 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4461 let handle = cx
4462 .update(|cx| {
4463 NativeThreadEnvironment::create_subagent_thread(
4464 native_agent.downgrade(),
4465 parent_thread.clone(),
4466 "some title".to_string(),
4467 "some task".to_string(),
4468 None,
4469 None,
4470 cx,
4471 )
4472 })
4473 .expect("Expected to be able to create subagent thread");
4474 handles.push(handle);
4475 }
4476
4477 let result = cx.update(|cx| {
4478 NativeThreadEnvironment::create_subagent_thread(
4479 native_agent.downgrade(),
4480 parent_thread.clone(),
4481 "some title".to_string(),
4482 "some task".to_string(),
4483 None,
4484 None,
4485 cx,
4486 )
4487 });
4488 assert!(result.is_err());
4489 assert_eq!(
4490 result.err().unwrap().to_string(),
4491 format!(
4492 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4493 MAX_PARALLEL_SUBAGENTS
4494 )
4495 );
4496}
4497
4498#[gpui::test]
4499async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4500 init_test(cx);
4501
4502 always_allow_tools(cx);
4503
4504 cx.update(|cx| {
4505 cx.update_flags(true, vec!["subagents".to_string()]);
4506 });
4507
4508 let fs = FakeFs::new(cx.executor());
4509 fs.insert_tree(path!("/test"), json!({})).await;
4510 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4511 let project_context = cx.new(|_cx| ProjectContext::default());
4512 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4513 let context_server_registry =
4514 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4515 cx.update(LanguageModelRegistry::test);
4516 let model = Arc::new(FakeLanguageModel::default());
4517 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4518 let native_agent = NativeAgent::new(
4519 project.clone(),
4520 thread_store,
4521 Templates::new(),
4522 None,
4523 fs,
4524 &mut cx.to_async(),
4525 )
4526 .await
4527 .unwrap();
4528 let parent_thread = cx.new(|cx| {
4529 Thread::new(
4530 project.clone(),
4531 project_context,
4532 context_server_registry,
4533 Templates::new(),
4534 Some(model.clone()),
4535 cx,
4536 )
4537 });
4538
4539 let subagent_handle = cx
4540 .update(|cx| {
4541 NativeThreadEnvironment::create_subagent_thread(
4542 native_agent.downgrade(),
4543 parent_thread.clone(),
4544 "some title".to_string(),
4545 "task prompt".to_string(),
4546 Some(Duration::from_millis(10)),
4547 None,
4548 cx,
4549 )
4550 })
4551 .expect("Failed to create subagent");
4552
4553 let summary_task =
4554 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4555
4556 cx.run_until_parked();
4557
4558 {
4559 let messages = model.pending_completions().last().unwrap().messages.clone();
4560 // Ensure that model received a system prompt
4561 assert_eq!(messages[0].role, Role::System);
4562 // Ensure that model received a task prompt
4563 assert_eq!(messages[1].role, Role::User);
4564 assert_eq!(
4565 messages[1].content,
4566 vec![MessageContent::Text("task prompt".to_string())]
4567 );
4568 }
4569
4570 model.send_last_completion_stream_text_chunk("Some task response...");
4571 model.end_last_completion_stream();
4572
4573 cx.run_until_parked();
4574
4575 {
4576 let messages = model.pending_completions().last().unwrap().messages.clone();
4577 assert_eq!(messages[2].role, Role::Assistant);
4578 assert_eq!(
4579 messages[2].content,
4580 vec![MessageContent::Text("Some task response...".to_string())]
4581 );
4582 // Ensure that model received a summary prompt
4583 assert_eq!(messages[3].role, Role::User);
4584 assert_eq!(
4585 messages[3].content,
4586 vec![MessageContent::Text("summary prompt".to_string())]
4587 );
4588 }
4589
4590 model.send_last_completion_stream_text_chunk("Some summary...");
4591 model.end_last_completion_stream();
4592
4593 let result = summary_task.await;
4594 assert_eq!(result.unwrap(), "Some summary...\n");
4595}
4596
4597#[gpui::test]
4598async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4599 cx: &mut TestAppContext,
4600) {
4601 init_test(cx);
4602
4603 always_allow_tools(cx);
4604
4605 cx.update(|cx| {
4606 cx.update_flags(true, vec!["subagents".to_string()]);
4607 });
4608
4609 let fs = FakeFs::new(cx.executor());
4610 fs.insert_tree(path!("/test"), json!({})).await;
4611 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4612 let project_context = cx.new(|_cx| ProjectContext::default());
4613 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4614 let context_server_registry =
4615 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4616 cx.update(LanguageModelRegistry::test);
4617 let model = Arc::new(FakeLanguageModel::default());
4618 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4619 let native_agent = NativeAgent::new(
4620 project.clone(),
4621 thread_store,
4622 Templates::new(),
4623 None,
4624 fs,
4625 &mut cx.to_async(),
4626 )
4627 .await
4628 .unwrap();
4629 let parent_thread = cx.new(|cx| {
4630 Thread::new(
4631 project.clone(),
4632 project_context,
4633 context_server_registry,
4634 Templates::new(),
4635 Some(model.clone()),
4636 cx,
4637 )
4638 });
4639
4640 let subagent_handle = cx
4641 .update(|cx| {
4642 NativeThreadEnvironment::create_subagent_thread(
4643 native_agent.downgrade(),
4644 parent_thread.clone(),
4645 "some title".to_string(),
4646 "task prompt".to_string(),
4647 Some(Duration::from_millis(100)),
4648 None,
4649 cx,
4650 )
4651 })
4652 .expect("Failed to create subagent");
4653
4654 let summary_task =
4655 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4656
4657 cx.run_until_parked();
4658
4659 {
4660 let messages = model.pending_completions().last().unwrap().messages.clone();
4661 // Ensure that model received a system prompt
4662 assert_eq!(messages[0].role, Role::System);
4663 // Ensure that model received a task prompt
4664 assert_eq!(
4665 messages[1].content,
4666 vec![MessageContent::Text("task prompt".to_string())]
4667 );
4668 }
4669
4670 // Don't complete the initial model stream — let the timeout expire instead.
4671 cx.executor().advance_clock(Duration::from_millis(200));
4672 cx.run_until_parked();
4673
4674 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4675 // instead of the summary_prompt.
4676 {
4677 let messages = model.pending_completions().last().unwrap().messages.clone();
4678 let last_user_message = messages
4679 .iter()
4680 .rev()
4681 .find(|m| m.role == Role::User)
4682 .unwrap();
4683 assert_eq!(
4684 last_user_message.content,
4685 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
4686 );
4687 }
4688
4689 model.send_last_completion_stream_text_chunk("Some context low response...");
4690 model.end_last_completion_stream();
4691
4692 let result = summary_task.await;
4693 assert_eq!(result.unwrap(), "Some context low response...\n");
4694}
4695
4696#[gpui::test]
4697async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4698 init_test(cx);
4699
4700 always_allow_tools(cx);
4701
4702 cx.update(|cx| {
4703 cx.update_flags(true, vec!["subagents".to_string()]);
4704 });
4705
4706 let fs = FakeFs::new(cx.executor());
4707 fs.insert_tree(path!("/test"), json!({})).await;
4708 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4709 let project_context = cx.new(|_cx| ProjectContext::default());
4710 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4711 let context_server_registry =
4712 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4713 cx.update(LanguageModelRegistry::test);
4714 let model = Arc::new(FakeLanguageModel::default());
4715 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4716 let native_agent = NativeAgent::new(
4717 project.clone(),
4718 thread_store,
4719 Templates::new(),
4720 None,
4721 fs,
4722 &mut cx.to_async(),
4723 )
4724 .await
4725 .unwrap();
4726 let parent_thread = cx.new(|cx| {
4727 let mut thread = Thread::new(
4728 project.clone(),
4729 project_context,
4730 context_server_registry,
4731 Templates::new(),
4732 Some(model.clone()),
4733 cx,
4734 );
4735 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4736 thread.add_tool(GrepTool::new(project.clone()), None);
4737 thread
4738 });
4739
4740 let _subagent_handle = cx
4741 .update(|cx| {
4742 NativeThreadEnvironment::create_subagent_thread(
4743 native_agent.downgrade(),
4744 parent_thread.clone(),
4745 "some title".to_string(),
4746 "task prompt".to_string(),
4747 Some(Duration::from_millis(10)),
4748 None,
4749 cx,
4750 )
4751 })
4752 .expect("Failed to create subagent");
4753
4754 cx.run_until_parked();
4755
4756 let tools = model
4757 .pending_completions()
4758 .last()
4759 .unwrap()
4760 .tools
4761 .iter()
4762 .map(|tool| tool.name.clone())
4763 .collect::<Vec<_>>();
4764 assert_eq!(tools.len(), 2);
4765 assert!(tools.contains(&"grep".to_string()));
4766 assert!(tools.contains(&"list_directory".to_string()));
4767}
4768
4769#[gpui::test]
4770async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4771 init_test(cx);
4772
4773 always_allow_tools(cx);
4774
4775 cx.update(|cx| {
4776 cx.update_flags(true, vec!["subagents".to_string()]);
4777 });
4778
4779 let fs = FakeFs::new(cx.executor());
4780 fs.insert_tree(path!("/test"), json!({})).await;
4781 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4782 let project_context = cx.new(|_cx| ProjectContext::default());
4783 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4784 let context_server_registry =
4785 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4786 cx.update(LanguageModelRegistry::test);
4787 let model = Arc::new(FakeLanguageModel::default());
4788 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4789 let native_agent = NativeAgent::new(
4790 project.clone(),
4791 thread_store,
4792 Templates::new(),
4793 None,
4794 fs,
4795 &mut cx.to_async(),
4796 )
4797 .await
4798 .unwrap();
4799 let parent_thread = cx.new(|cx| {
4800 let mut thread = Thread::new(
4801 project.clone(),
4802 project_context,
4803 context_server_registry,
4804 Templates::new(),
4805 Some(model.clone()),
4806 cx,
4807 );
4808 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4809 thread.add_tool(GrepTool::new(project.clone()), None);
4810 thread
4811 });
4812
4813 let _subagent_handle = cx
4814 .update(|cx| {
4815 NativeThreadEnvironment::create_subagent_thread(
4816 native_agent.downgrade(),
4817 parent_thread.clone(),
4818 "some title".to_string(),
4819 "task prompt".to_string(),
4820 Some(Duration::from_millis(10)),
4821 Some(vec!["grep".to_string()]),
4822 cx,
4823 )
4824 })
4825 .expect("Failed to create subagent");
4826
4827 cx.run_until_parked();
4828
4829 let tools = model
4830 .pending_completions()
4831 .last()
4832 .unwrap()
4833 .tools
4834 .iter()
4835 .map(|tool| tool.name.clone())
4836 .collect::<Vec<_>>();
4837 assert_eq!(tools, vec!["grep"]);
4838}
4839
4840#[gpui::test]
4841async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4842 init_test(cx);
4843
4844 let fs = FakeFs::new(cx.executor());
4845 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4846 .await;
4847 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4848
4849 cx.update(|cx| {
4850 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4851 settings.tool_permissions.tools.insert(
4852 EditFileTool::NAME.into(),
4853 agent_settings::ToolRules {
4854 default: Some(settings::ToolPermissionMode::Allow),
4855 always_allow: vec![],
4856 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4857 always_confirm: vec![],
4858 invalid_patterns: vec![],
4859 },
4860 );
4861 agent_settings::AgentSettings::override_global(settings, cx);
4862 });
4863
4864 let context_server_registry =
4865 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4866 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4867 let templates = crate::Templates::new();
4868 let thread = cx.new(|cx| {
4869 crate::Thread::new(
4870 project.clone(),
4871 cx.new(|_cx| prompt_store::ProjectContext::default()),
4872 context_server_registry,
4873 templates.clone(),
4874 None,
4875 cx,
4876 )
4877 });
4878
4879 #[allow(clippy::arc_with_non_send_sync)]
4880 let tool = Arc::new(crate::EditFileTool::new(
4881 project.clone(),
4882 thread.downgrade(),
4883 language_registry,
4884 templates,
4885 ));
4886 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4887
4888 let task = cx.update(|cx| {
4889 tool.run(
4890 crate::EditFileToolInput {
4891 display_description: "Edit sensitive file".to_string(),
4892 path: "root/sensitive_config.txt".into(),
4893 mode: crate::EditFileMode::Edit,
4894 },
4895 event_stream,
4896 cx,
4897 )
4898 });
4899
4900 let result = task.await;
4901 assert!(result.is_err(), "expected edit to be blocked");
4902 assert!(
4903 result.unwrap_err().to_string().contains("blocked"),
4904 "error should mention the edit was blocked"
4905 );
4906}
4907
4908#[gpui::test]
4909async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4910 init_test(cx);
4911
4912 let fs = FakeFs::new(cx.executor());
4913 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4914 .await;
4915 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4916
4917 cx.update(|cx| {
4918 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4919 settings.tool_permissions.tools.insert(
4920 DeletePathTool::NAME.into(),
4921 agent_settings::ToolRules {
4922 default: Some(settings::ToolPermissionMode::Allow),
4923 always_allow: vec![],
4924 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4925 always_confirm: vec![],
4926 invalid_patterns: vec![],
4927 },
4928 );
4929 agent_settings::AgentSettings::override_global(settings, cx);
4930 });
4931
4932 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4933
4934 #[allow(clippy::arc_with_non_send_sync)]
4935 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4936 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4937
4938 let task = cx.update(|cx| {
4939 tool.run(
4940 crate::DeletePathToolInput {
4941 path: "root/important_data.txt".to_string(),
4942 },
4943 event_stream,
4944 cx,
4945 )
4946 });
4947
4948 let result = task.await;
4949 assert!(result.is_err(), "expected deletion to be blocked");
4950 assert!(
4951 result.unwrap_err().to_string().contains("blocked"),
4952 "error should mention the deletion was blocked"
4953 );
4954}
4955
4956#[gpui::test]
4957async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4958 init_test(cx);
4959
4960 let fs = FakeFs::new(cx.executor());
4961 fs.insert_tree(
4962 "/root",
4963 json!({
4964 "safe.txt": "content",
4965 "protected": {}
4966 }),
4967 )
4968 .await;
4969 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4970
4971 cx.update(|cx| {
4972 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4973 settings.tool_permissions.tools.insert(
4974 MovePathTool::NAME.into(),
4975 agent_settings::ToolRules {
4976 default: Some(settings::ToolPermissionMode::Allow),
4977 always_allow: vec![],
4978 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4979 always_confirm: vec![],
4980 invalid_patterns: vec![],
4981 },
4982 );
4983 agent_settings::AgentSettings::override_global(settings, cx);
4984 });
4985
4986 #[allow(clippy::arc_with_non_send_sync)]
4987 let tool = Arc::new(crate::MovePathTool::new(project));
4988 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4989
4990 let task = cx.update(|cx| {
4991 tool.run(
4992 crate::MovePathToolInput {
4993 source_path: "root/safe.txt".to_string(),
4994 destination_path: "root/protected/safe.txt".to_string(),
4995 },
4996 event_stream,
4997 cx,
4998 )
4999 });
5000
5001 let result = task.await;
5002 assert!(
5003 result.is_err(),
5004 "expected move to be blocked due to destination path"
5005 );
5006 assert!(
5007 result.unwrap_err().to_string().contains("blocked"),
5008 "error should mention the move was blocked"
5009 );
5010}
5011
5012#[gpui::test]
5013async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5014 init_test(cx);
5015
5016 let fs = FakeFs::new(cx.executor());
5017 fs.insert_tree(
5018 "/root",
5019 json!({
5020 "secret.txt": "secret content",
5021 "public": {}
5022 }),
5023 )
5024 .await;
5025 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5026
5027 cx.update(|cx| {
5028 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5029 settings.tool_permissions.tools.insert(
5030 MovePathTool::NAME.into(),
5031 agent_settings::ToolRules {
5032 default: Some(settings::ToolPermissionMode::Allow),
5033 always_allow: vec![],
5034 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5035 always_confirm: vec![],
5036 invalid_patterns: vec![],
5037 },
5038 );
5039 agent_settings::AgentSettings::override_global(settings, cx);
5040 });
5041
5042 #[allow(clippy::arc_with_non_send_sync)]
5043 let tool = Arc::new(crate::MovePathTool::new(project));
5044 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5045
5046 let task = cx.update(|cx| {
5047 tool.run(
5048 crate::MovePathToolInput {
5049 source_path: "root/secret.txt".to_string(),
5050 destination_path: "root/public/not_secret.txt".to_string(),
5051 },
5052 event_stream,
5053 cx,
5054 )
5055 });
5056
5057 let result = task.await;
5058 assert!(
5059 result.is_err(),
5060 "expected move to be blocked due to source path"
5061 );
5062 assert!(
5063 result.unwrap_err().to_string().contains("blocked"),
5064 "error should mention the move was blocked"
5065 );
5066}
5067
5068#[gpui::test]
5069async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5070 init_test(cx);
5071
5072 let fs = FakeFs::new(cx.executor());
5073 fs.insert_tree(
5074 "/root",
5075 json!({
5076 "confidential.txt": "confidential data",
5077 "dest": {}
5078 }),
5079 )
5080 .await;
5081 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5082
5083 cx.update(|cx| {
5084 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5085 settings.tool_permissions.tools.insert(
5086 CopyPathTool::NAME.into(),
5087 agent_settings::ToolRules {
5088 default: Some(settings::ToolPermissionMode::Allow),
5089 always_allow: vec![],
5090 always_deny: vec![
5091 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5092 ],
5093 always_confirm: vec![],
5094 invalid_patterns: vec![],
5095 },
5096 );
5097 agent_settings::AgentSettings::override_global(settings, cx);
5098 });
5099
5100 #[allow(clippy::arc_with_non_send_sync)]
5101 let tool = Arc::new(crate::CopyPathTool::new(project));
5102 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5103
5104 let task = cx.update(|cx| {
5105 tool.run(
5106 crate::CopyPathToolInput {
5107 source_path: "root/confidential.txt".to_string(),
5108 destination_path: "root/dest/copy.txt".to_string(),
5109 },
5110 event_stream,
5111 cx,
5112 )
5113 });
5114
5115 let result = task.await;
5116 assert!(result.is_err(), "expected copy to be blocked");
5117 assert!(
5118 result.unwrap_err().to_string().contains("blocked"),
5119 "error should mention the copy was blocked"
5120 );
5121}
5122
5123#[gpui::test]
5124async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5125 init_test(cx);
5126
5127 let fs = FakeFs::new(cx.executor());
5128 fs.insert_tree(
5129 "/root",
5130 json!({
5131 "normal.txt": "normal content",
5132 "readonly": {
5133 "config.txt": "readonly content"
5134 }
5135 }),
5136 )
5137 .await;
5138 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5139
5140 cx.update(|cx| {
5141 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5142 settings.tool_permissions.tools.insert(
5143 SaveFileTool::NAME.into(),
5144 agent_settings::ToolRules {
5145 default: Some(settings::ToolPermissionMode::Allow),
5146 always_allow: vec![],
5147 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5148 always_confirm: vec![],
5149 invalid_patterns: vec![],
5150 },
5151 );
5152 agent_settings::AgentSettings::override_global(settings, cx);
5153 });
5154
5155 #[allow(clippy::arc_with_non_send_sync)]
5156 let tool = Arc::new(crate::SaveFileTool::new(project));
5157 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5158
5159 let task = cx.update(|cx| {
5160 tool.run(
5161 crate::SaveFileToolInput {
5162 paths: vec![
5163 std::path::PathBuf::from("root/normal.txt"),
5164 std::path::PathBuf::from("root/readonly/config.txt"),
5165 ],
5166 },
5167 event_stream,
5168 cx,
5169 )
5170 });
5171
5172 let result = task.await;
5173 assert!(
5174 result.is_err(),
5175 "expected save to be blocked due to denied path"
5176 );
5177 assert!(
5178 result.unwrap_err().to_string().contains("blocked"),
5179 "error should mention the save was blocked"
5180 );
5181}
5182
5183#[gpui::test]
5184async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5185 init_test(cx);
5186
5187 let fs = FakeFs::new(cx.executor());
5188 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5189 .await;
5190 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5191
5192 cx.update(|cx| {
5193 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5194 settings.tool_permissions.tools.insert(
5195 SaveFileTool::NAME.into(),
5196 agent_settings::ToolRules {
5197 default: Some(settings::ToolPermissionMode::Allow),
5198 always_allow: vec![],
5199 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5200 always_confirm: vec![],
5201 invalid_patterns: vec![],
5202 },
5203 );
5204 agent_settings::AgentSettings::override_global(settings, cx);
5205 });
5206
5207 #[allow(clippy::arc_with_non_send_sync)]
5208 let tool = Arc::new(crate::SaveFileTool::new(project));
5209 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5210
5211 let task = cx.update(|cx| {
5212 tool.run(
5213 crate::SaveFileToolInput {
5214 paths: vec![std::path::PathBuf::from("root/config.secret")],
5215 },
5216 event_stream,
5217 cx,
5218 )
5219 });
5220
5221 let result = task.await;
5222 assert!(result.is_err(), "expected save to be blocked");
5223 assert!(
5224 result.unwrap_err().to_string().contains("blocked"),
5225 "error should mention the save was blocked"
5226 );
5227}
5228
5229#[gpui::test]
5230async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5231 init_test(cx);
5232
5233 cx.update(|cx| {
5234 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5235 settings.tool_permissions.tools.insert(
5236 WebSearchTool::NAME.into(),
5237 agent_settings::ToolRules {
5238 default: Some(settings::ToolPermissionMode::Allow),
5239 always_allow: vec![],
5240 always_deny: vec![
5241 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5242 ],
5243 always_confirm: vec![],
5244 invalid_patterns: vec![],
5245 },
5246 );
5247 agent_settings::AgentSettings::override_global(settings, cx);
5248 });
5249
5250 #[allow(clippy::arc_with_non_send_sync)]
5251 let tool = Arc::new(crate::WebSearchTool);
5252 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5253
5254 let input: crate::WebSearchToolInput =
5255 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5256
5257 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5258
5259 let result = task.await;
5260 assert!(result.is_err(), "expected search to be blocked");
5261 assert!(
5262 result.unwrap_err().to_string().contains("blocked"),
5263 "error should mention the search was blocked"
5264 );
5265}
5266
5267#[gpui::test]
5268async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5269 init_test(cx);
5270
5271 let fs = FakeFs::new(cx.executor());
5272 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5273 .await;
5274 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5275
5276 cx.update(|cx| {
5277 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5278 settings.tool_permissions.tools.insert(
5279 EditFileTool::NAME.into(),
5280 agent_settings::ToolRules {
5281 default: Some(settings::ToolPermissionMode::Confirm),
5282 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5283 always_deny: vec![],
5284 always_confirm: vec![],
5285 invalid_patterns: vec![],
5286 },
5287 );
5288 agent_settings::AgentSettings::override_global(settings, cx);
5289 });
5290
5291 let context_server_registry =
5292 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5293 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5294 let templates = crate::Templates::new();
5295 let thread = cx.new(|cx| {
5296 crate::Thread::new(
5297 project.clone(),
5298 cx.new(|_cx| prompt_store::ProjectContext::default()),
5299 context_server_registry,
5300 templates.clone(),
5301 None,
5302 cx,
5303 )
5304 });
5305
5306 #[allow(clippy::arc_with_non_send_sync)]
5307 let tool = Arc::new(crate::EditFileTool::new(
5308 project,
5309 thread.downgrade(),
5310 language_registry,
5311 templates,
5312 ));
5313 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5314
5315 let _task = cx.update(|cx| {
5316 tool.run(
5317 crate::EditFileToolInput {
5318 display_description: "Edit README".to_string(),
5319 path: "root/README.md".into(),
5320 mode: crate::EditFileMode::Edit,
5321 },
5322 event_stream,
5323 cx,
5324 )
5325 });
5326
5327 cx.run_until_parked();
5328
5329 let event = rx.try_next();
5330 assert!(
5331 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5332 "expected no authorization request for allowed .md file"
5333 );
5334}
5335
5336#[gpui::test]
5337async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5338 init_test(cx);
5339
5340 let fs = FakeFs::new(cx.executor());
5341 fs.insert_tree(
5342 "/root",
5343 json!({
5344 ".zed": {
5345 "settings.json": "{}"
5346 },
5347 "README.md": "# Hello"
5348 }),
5349 )
5350 .await;
5351 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5352
5353 cx.update(|cx| {
5354 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5355 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5356 agent_settings::AgentSettings::override_global(settings, cx);
5357 });
5358
5359 let context_server_registry =
5360 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5361 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5362 let templates = crate::Templates::new();
5363 let thread = cx.new(|cx| {
5364 crate::Thread::new(
5365 project.clone(),
5366 cx.new(|_cx| prompt_store::ProjectContext::default()),
5367 context_server_registry,
5368 templates.clone(),
5369 None,
5370 cx,
5371 )
5372 });
5373
5374 #[allow(clippy::arc_with_non_send_sync)]
5375 let tool = Arc::new(crate::EditFileTool::new(
5376 project,
5377 thread.downgrade(),
5378 language_registry,
5379 templates,
5380 ));
5381
5382 // Editing a file inside .zed/ should still prompt even with global default: allow,
5383 // because local settings paths are sensitive and require confirmation regardless.
5384 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5385 let _task = cx.update(|cx| {
5386 tool.run(
5387 crate::EditFileToolInput {
5388 display_description: "Edit local settings".to_string(),
5389 path: "root/.zed/settings.json".into(),
5390 mode: crate::EditFileMode::Edit,
5391 },
5392 event_stream,
5393 cx,
5394 )
5395 });
5396
5397 let _update = rx.expect_update_fields().await;
5398 let _auth = rx.expect_authorization().await;
5399}
5400
5401#[gpui::test]
5402async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5403 init_test(cx);
5404
5405 cx.update(|cx| {
5406 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5407 settings.tool_permissions.tools.insert(
5408 FetchTool::NAME.into(),
5409 agent_settings::ToolRules {
5410 default: Some(settings::ToolPermissionMode::Allow),
5411 always_allow: vec![],
5412 always_deny: vec![
5413 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5414 ],
5415 always_confirm: vec![],
5416 invalid_patterns: vec![],
5417 },
5418 );
5419 agent_settings::AgentSettings::override_global(settings, cx);
5420 });
5421
5422 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5423
5424 #[allow(clippy::arc_with_non_send_sync)]
5425 let tool = Arc::new(crate::FetchTool::new(http_client));
5426 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5427
5428 let input: crate::FetchToolInput =
5429 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5430
5431 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5432
5433 let result = task.await;
5434 assert!(result.is_err(), "expected fetch to be blocked");
5435 assert!(
5436 result.unwrap_err().to_string().contains("blocked"),
5437 "error should mention the fetch was blocked"
5438 );
5439}
5440
5441#[gpui::test]
5442async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5443 init_test(cx);
5444
5445 cx.update(|cx| {
5446 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5447 settings.tool_permissions.tools.insert(
5448 FetchTool::NAME.into(),
5449 agent_settings::ToolRules {
5450 default: Some(settings::ToolPermissionMode::Confirm),
5451 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5452 always_deny: vec![],
5453 always_confirm: vec![],
5454 invalid_patterns: vec![],
5455 },
5456 );
5457 agent_settings::AgentSettings::override_global(settings, cx);
5458 });
5459
5460 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5461
5462 #[allow(clippy::arc_with_non_send_sync)]
5463 let tool = Arc::new(crate::FetchTool::new(http_client));
5464 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5465
5466 let input: crate::FetchToolInput =
5467 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5468
5469 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5470
5471 cx.run_until_parked();
5472
5473 let event = rx.try_next();
5474 assert!(
5475 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5476 "expected no authorization request for allowed docs.rs URL"
5477 );
5478}
5479
5480#[gpui::test]
5481async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5482 init_test(cx);
5483 always_allow_tools(cx);
5484
5485 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5486 let fake_model = model.as_fake();
5487
5488 // Add a tool so we can simulate tool calls
5489 thread.update(cx, |thread, _cx| {
5490 thread.add_tool(EchoTool, None);
5491 });
5492
5493 // Start a turn by sending a message
5494 let mut events = thread
5495 .update(cx, |thread, cx| {
5496 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5497 })
5498 .unwrap();
5499 cx.run_until_parked();
5500
5501 // Simulate the model making a tool call
5502 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5503 LanguageModelToolUse {
5504 id: "tool_1".into(),
5505 name: "echo".into(),
5506 raw_input: r#"{"text": "hello"}"#.into(),
5507 input: json!({"text": "hello"}),
5508 is_input_complete: true,
5509 thought_signature: None,
5510 },
5511 ));
5512 fake_model
5513 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5514
5515 // Signal that a message is queued before ending the stream
5516 thread.update(cx, |thread, _cx| {
5517 thread.set_has_queued_message(true);
5518 });
5519
5520 // Now end the stream - tool will run, and the boundary check should see the queue
5521 fake_model.end_last_completion_stream();
5522
5523 // Collect all events until the turn stops
5524 let all_events = collect_events_until_stop(&mut events, cx).await;
5525
5526 // Verify we received the tool call event
5527 let tool_call_ids: Vec<_> = all_events
5528 .iter()
5529 .filter_map(|e| match e {
5530 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5531 _ => None,
5532 })
5533 .collect();
5534 assert_eq!(
5535 tool_call_ids,
5536 vec!["tool_1"],
5537 "Should have received a tool call event for our echo tool"
5538 );
5539
5540 // The turn should have stopped with EndTurn
5541 let stop_reasons = stop_events(all_events);
5542 assert_eq!(
5543 stop_reasons,
5544 vec![acp::StopReason::EndTurn],
5545 "Turn should have ended after tool completion due to queued message"
5546 );
5547
5548 // Verify the queued message flag is still set
5549 thread.update(cx, |thread, _cx| {
5550 assert!(
5551 thread.has_queued_message(),
5552 "Should still have queued message flag set"
5553 );
5554 });
5555
5556 // Thread should be idle now
5557 thread.update(cx, |thread, _cx| {
5558 assert!(
5559 thread.is_turn_complete(),
5560 "Thread should not be running after turn ends"
5561 );
5562 });
5563}