1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
66struct FakeTerminalHandle {
67 killed: Arc<AtomicBool>,
68 stopped_by_user: Arc<AtomicBool>,
69 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
70 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
71 output: acp::TerminalOutputResponse,
72 id: acp::TerminalId,
73}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
158struct FakeSubagentHandle {
159 session_id: acp::SessionId,
160 wait_for_summary_task: Shared<Task<String>>,
161}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
183#[derive(Default)]
184struct FakeThreadEnvironment {
185 terminal_handle: Option<Rc<FakeTerminalHandle>>,
186 subagent_handle: Option<Rc<FakeSubagentHandle>>,
187}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
237/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
238struct MultiTerminalEnvironment {
239 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
240}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
288#[gpui::test]
289async fn test_echo(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let events = thread
294 .update(cx, |thread, cx| {
295 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
296 })
297 .unwrap();
298 cx.run_until_parked();
299 fake_model.send_last_completion_stream_text_chunk("Hello");
300 fake_model
301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
302 fake_model.end_last_completion_stream();
303
304 let events = events.collect().await;
305 thread.update(cx, |thread, _cx| {
306 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
307 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
308 });
309 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
310}
311
312#[gpui::test]
313async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
314 init_test(cx);
315 always_allow_tools(cx);
316
317 let fs = FakeFs::new(cx.executor());
318 let project = Project::test(fs, [], cx).await;
319
320 let environment = Rc::new(cx.update(|cx| {
321 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
322 }));
323 let handle = environment.terminal_handle.clone().unwrap();
324
325 #[allow(clippy::arc_with_non_send_sync)]
326 let tool = Arc::new(crate::TerminalTool::new(project, environment));
327 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
328
329 let task = cx.update(|cx| {
330 tool.run(
331 crate::TerminalToolInput {
332 command: "sleep 1000".to_string(),
333 cd: ".".to_string(),
334 timeout_ms: Some(5),
335 },
336 event_stream,
337 cx,
338 )
339 });
340
341 let update = rx.expect_update_fields().await;
342 assert!(
343 update.content.iter().any(|blocks| {
344 blocks
345 .iter()
346 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
347 }),
348 "expected tool call update to include terminal content"
349 );
350
351 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
352
353 let deadline = std::time::Instant::now() + Duration::from_millis(500);
354 loop {
355 if let Some(result) = task_future.as_mut().now_or_never() {
356 let result = result.expect("terminal tool task should complete");
357
358 assert!(
359 handle.was_killed(),
360 "expected terminal handle to be killed on timeout"
361 );
362 assert!(
363 result.contains("partial output"),
364 "expected result to include terminal output, got: {result}"
365 );
366 return;
367 }
368
369 if std::time::Instant::now() >= deadline {
370 panic!("timed out waiting for terminal tool task to complete");
371 }
372
373 cx.run_until_parked();
374 cx.background_executor.timer(Duration::from_millis(1)).await;
375 }
376}
377
378#[gpui::test]
379#[ignore]
380async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
381 init_test(cx);
382 always_allow_tools(cx);
383
384 let fs = FakeFs::new(cx.executor());
385 let project = Project::test(fs, [], cx).await;
386
387 let environment = Rc::new(cx.update(|cx| {
388 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
389 }));
390 let handle = environment.terminal_handle.clone().unwrap();
391
392 #[allow(clippy::arc_with_non_send_sync)]
393 let tool = Arc::new(crate::TerminalTool::new(project, environment));
394 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
395
396 let _task = cx.update(|cx| {
397 tool.run(
398 crate::TerminalToolInput {
399 command: "sleep 1000".to_string(),
400 cd: ".".to_string(),
401 timeout_ms: None,
402 },
403 event_stream,
404 cx,
405 )
406 });
407
408 let update = rx.expect_update_fields().await;
409 assert!(
410 update.content.iter().any(|blocks| {
411 blocks
412 .iter()
413 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
414 }),
415 "expected tool call update to include terminal content"
416 );
417
418 cx.background_executor
419 .timer(Duration::from_millis(25))
420 .await;
421
422 assert!(
423 !handle.was_killed(),
424 "did not expect terminal handle to be killed without a timeout"
425 );
426}
427
428#[gpui::test]
429async fn test_thinking(cx: &mut TestAppContext) {
430 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
431 let fake_model = model.as_fake();
432
433 let events = thread
434 .update(cx, |thread, cx| {
435 thread.send(
436 UserMessageId::new(),
437 [indoc! {"
438 Testing:
439
440 Generate a thinking step where you just think the word 'Think',
441 and have your final answer be 'Hello'
442 "}],
443 cx,
444 )
445 })
446 .unwrap();
447 cx.run_until_parked();
448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
449 text: "Think".to_string(),
450 signature: None,
451 });
452 fake_model.send_last_completion_stream_text_chunk("Hello");
453 fake_model
454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
455 fake_model.end_last_completion_stream();
456
457 let events = events.collect().await;
458 thread.update(cx, |thread, _cx| {
459 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
460 assert_eq!(
461 thread.last_message().unwrap().to_markdown(),
462 indoc! {"
463 <think>Think</think>
464 Hello
465 "}
466 )
467 });
468 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
469}
470
471#[gpui::test]
472async fn test_system_prompt(cx: &mut TestAppContext) {
473 let ThreadTest {
474 model,
475 thread,
476 project_context,
477 ..
478 } = setup(cx, TestModel::Fake).await;
479 let fake_model = model.as_fake();
480
481 project_context.update(cx, |project_context, _cx| {
482 project_context.shell = "test-shell".into()
483 });
484 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["abc"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491 let mut pending_completions = fake_model.pending_completions();
492 assert_eq!(
493 pending_completions.len(),
494 1,
495 "unexpected pending completions: {:?}",
496 pending_completions
497 );
498
499 let pending_completion = pending_completions.pop().unwrap();
500 assert_eq!(pending_completion.messages[0].role, Role::System);
501
502 let system_message = &pending_completion.messages[0];
503 let system_prompt = system_message.content[0].to_str().unwrap();
504 assert!(
505 system_prompt.contains("test-shell"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509 assert!(
510 system_prompt.contains("## Fixing Diagnostics"),
511 "unexpected system message: {:?}",
512 system_message
513 );
514}
515
516#[gpui::test]
517async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
519 let fake_model = model.as_fake();
520
521 thread
522 .update(cx, |thread, cx| {
523 thread.send(UserMessageId::new(), ["abc"], cx)
524 })
525 .unwrap();
526 cx.run_until_parked();
527 let mut pending_completions = fake_model.pending_completions();
528 assert_eq!(
529 pending_completions.len(),
530 1,
531 "unexpected pending completions: {:?}",
532 pending_completions
533 );
534
535 let pending_completion = pending_completions.pop().unwrap();
536 assert_eq!(pending_completion.messages[0].role, Role::System);
537
538 let system_message = &pending_completion.messages[0];
539 let system_prompt = system_message.content[0].to_str().unwrap();
540 assert!(
541 !system_prompt.contains("## Tool Use"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545 assert!(
546 !system_prompt.contains("## Fixing Diagnostics"),
547 "unexpected system message: {:?}",
548 system_message
549 );
550}
551
552#[gpui::test]
553async fn test_prompt_caching(cx: &mut TestAppContext) {
554 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
555 let fake_model = model.as_fake();
556
557 // Send initial user message and verify it's cached
558 thread
559 .update(cx, |thread, cx| {
560 thread.send(UserMessageId::new(), ["Message 1"], cx)
561 })
562 .unwrap();
563 cx.run_until_parked();
564
565 let completion = fake_model.pending_completions().pop().unwrap();
566 assert_eq!(
567 completion.messages[1..],
568 vec![LanguageModelRequestMessage {
569 role: Role::User,
570 content: vec!["Message 1".into()],
571 cache: true,
572 reasoning_details: None,
573 }]
574 );
575 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
576 "Response to Message 1".into(),
577 ));
578 fake_model.end_last_completion_stream();
579 cx.run_until_parked();
580
581 // Send another user message and verify only the latest is cached
582 thread
583 .update(cx, |thread, cx| {
584 thread.send(UserMessageId::new(), ["Message 2"], cx)
585 })
586 .unwrap();
587 cx.run_until_parked();
588
589 let completion = fake_model.pending_completions().pop().unwrap();
590 assert_eq!(
591 completion.messages[1..],
592 vec![
593 LanguageModelRequestMessage {
594 role: Role::User,
595 content: vec!["Message 1".into()],
596 cache: false,
597 reasoning_details: None,
598 },
599 LanguageModelRequestMessage {
600 role: Role::Assistant,
601 content: vec!["Response to Message 1".into()],
602 cache: false,
603 reasoning_details: None,
604 },
605 LanguageModelRequestMessage {
606 role: Role::User,
607 content: vec!["Message 2".into()],
608 cache: true,
609 reasoning_details: None,
610 }
611 ]
612 );
613 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
614 "Response to Message 2".into(),
615 ));
616 fake_model.end_last_completion_stream();
617 cx.run_until_parked();
618
619 // Simulate a tool call and verify that the latest tool result is cached
620 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
621 thread
622 .update(cx, |thread, cx| {
623 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
624 })
625 .unwrap();
626 cx.run_until_parked();
627
628 let tool_use = LanguageModelToolUse {
629 id: "tool_1".into(),
630 name: EchoTool::NAME.into(),
631 raw_input: json!({"text": "test"}).to_string(),
632 input: json!({"text": "test"}),
633 is_input_complete: true,
634 thought_signature: None,
635 };
636 fake_model
637 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
638 fake_model.end_last_completion_stream();
639 cx.run_until_parked();
640
641 let completion = fake_model.pending_completions().pop().unwrap();
642 let tool_result = LanguageModelToolResult {
643 tool_use_id: "tool_1".into(),
644 tool_name: EchoTool::NAME.into(),
645 is_error: false,
646 content: "test".into(),
647 output: Some("test".into()),
648 };
649 assert_eq!(
650 completion.messages[1..],
651 vec![
652 LanguageModelRequestMessage {
653 role: Role::User,
654 content: vec!["Message 1".into()],
655 cache: false,
656 reasoning_details: None,
657 },
658 LanguageModelRequestMessage {
659 role: Role::Assistant,
660 content: vec!["Response to Message 1".into()],
661 cache: false,
662 reasoning_details: None,
663 },
664 LanguageModelRequestMessage {
665 role: Role::User,
666 content: vec!["Message 2".into()],
667 cache: false,
668 reasoning_details: None,
669 },
670 LanguageModelRequestMessage {
671 role: Role::Assistant,
672 content: vec!["Response to Message 2".into()],
673 cache: false,
674 reasoning_details: None,
675 },
676 LanguageModelRequestMessage {
677 role: Role::User,
678 content: vec!["Use the echo tool".into()],
679 cache: false,
680 reasoning_details: None,
681 },
682 LanguageModelRequestMessage {
683 role: Role::Assistant,
684 content: vec![MessageContent::ToolUse(tool_use)],
685 cache: false,
686 reasoning_details: None,
687 },
688 LanguageModelRequestMessage {
689 role: Role::User,
690 content: vec![MessageContent::ToolResult(tool_result)],
691 cache: true,
692 reasoning_details: None,
693 }
694 ]
695 );
696}
697
698#[gpui::test]
699#[cfg_attr(not(feature = "e2e"), ignore)]
700async fn test_basic_tool_calls(cx: &mut TestAppContext) {
701 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
702
703 // Test a tool call that's likely to complete *before* streaming stops.
704 let events = thread
705 .update(cx, |thread, cx| {
706 thread.add_tool(EchoTool, None);
707 thread.send(
708 UserMessageId::new(),
709 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
710 cx,
711 )
712 })
713 .unwrap()
714 .collect()
715 .await;
716 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
717
718 // Test a tool calls that's likely to complete *after* streaming stops.
719 let events = thread
720 .update(cx, |thread, cx| {
721 thread.remove_tool(&EchoTool::NAME);
722 thread.add_tool(DelayTool, None);
723 thread.send(
724 UserMessageId::new(),
725 [
726 "Now call the delay tool with 200ms.",
727 "When the timer goes off, then you echo the output of the tool.",
728 ],
729 cx,
730 )
731 })
732 .unwrap()
733 .collect()
734 .await;
735 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
736 thread.update(cx, |thread, _cx| {
737 assert!(
738 thread
739 .last_message()
740 .unwrap()
741 .as_agent_message()
742 .unwrap()
743 .content
744 .iter()
745 .any(|content| {
746 if let AgentMessageContent::Text(text) = content {
747 text.contains("Ding")
748 } else {
749 false
750 }
751 }),
752 "{}",
753 thread.to_markdown()
754 );
755 });
756}
757
758#[gpui::test]
759#[cfg_attr(not(feature = "e2e"), ignore)]
760async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
761 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
762
763 // Test a tool call that's likely to complete *before* streaming stops.
764 let mut events = thread
765 .update(cx, |thread, cx| {
766 thread.add_tool(WordListTool, None);
767 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
768 })
769 .unwrap();
770
771 let mut saw_partial_tool_use = false;
772 while let Some(event) = events.next().await {
773 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
774 thread.update(cx, |thread, _cx| {
775 // Look for a tool use in the thread's last message
776 let message = thread.last_message().unwrap();
777 let agent_message = message.as_agent_message().unwrap();
778 let last_content = agent_message.content.last().unwrap();
779 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
780 assert_eq!(last_tool_use.name.as_ref(), "word_list");
781 if tool_call.status == acp::ToolCallStatus::Pending {
782 if !last_tool_use.is_input_complete
783 && last_tool_use.input.get("g").is_none()
784 {
785 saw_partial_tool_use = true;
786 }
787 } else {
788 last_tool_use
789 .input
790 .get("a")
791 .expect("'a' has streamed because input is now complete");
792 last_tool_use
793 .input
794 .get("g")
795 .expect("'g' has streamed because input is now complete");
796 }
797 } else {
798 panic!("last content should be a tool use");
799 }
800 });
801 }
802 }
803
804 assert!(
805 saw_partial_tool_use,
806 "should see at least one partially streamed tool use in the history"
807 );
808}
809
810#[gpui::test]
811async fn test_tool_authorization(cx: &mut TestAppContext) {
812 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
813 let fake_model = model.as_fake();
814
815 let mut events = thread
816 .update(cx, |thread, cx| {
817 thread.add_tool(ToolRequiringPermission, None);
818 thread.send(UserMessageId::new(), ["abc"], cx)
819 })
820 .unwrap();
821 cx.run_until_parked();
822 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
823 LanguageModelToolUse {
824 id: "tool_id_1".into(),
825 name: ToolRequiringPermission::NAME.into(),
826 raw_input: "{}".into(),
827 input: json!({}),
828 is_input_complete: true,
829 thought_signature: None,
830 },
831 ));
832 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
833 LanguageModelToolUse {
834 id: "tool_id_2".into(),
835 name: ToolRequiringPermission::NAME.into(),
836 raw_input: "{}".into(),
837 input: json!({}),
838 is_input_complete: true,
839 thought_signature: None,
840 },
841 ));
842 fake_model.end_last_completion_stream();
843 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
844 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
845
846 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
847 tool_call_auth_1
848 .response
849 .send(acp::PermissionOptionId::new("allow"))
850 .unwrap();
851 cx.run_until_parked();
852
853 // Reject the second - send "deny" option_id directly since Deny is now a button
854 tool_call_auth_2
855 .response
856 .send(acp::PermissionOptionId::new("deny"))
857 .unwrap();
858 cx.run_until_parked();
859
860 let completion = fake_model.pending_completions().pop().unwrap();
861 let message = completion.messages.last().unwrap();
862 assert_eq!(
863 message.content,
864 vec![
865 language_model::MessageContent::ToolResult(LanguageModelToolResult {
866 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
867 tool_name: ToolRequiringPermission::NAME.into(),
868 is_error: false,
869 content: "Allowed".into(),
870 output: Some("Allowed".into())
871 }),
872 language_model::MessageContent::ToolResult(LanguageModelToolResult {
873 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
874 tool_name: ToolRequiringPermission::NAME.into(),
875 is_error: true,
876 content: "Permission to run tool denied by user".into(),
877 output: Some("Permission to run tool denied by user".into())
878 })
879 ]
880 );
881
882 // Simulate yet another tool call.
883 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
884 LanguageModelToolUse {
885 id: "tool_id_3".into(),
886 name: ToolRequiringPermission::NAME.into(),
887 raw_input: "{}".into(),
888 input: json!({}),
889 is_input_complete: true,
890 thought_signature: None,
891 },
892 ));
893 fake_model.end_last_completion_stream();
894
895 // Respond by always allowing tools - send transformed option_id
896 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
897 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
898 tool_call_auth_3
899 .response
900 .send(acp::PermissionOptionId::new(
901 "always_allow:tool_requiring_permission",
902 ))
903 .unwrap();
904 cx.run_until_parked();
905 let completion = fake_model.pending_completions().pop().unwrap();
906 let message = completion.messages.last().unwrap();
907 assert_eq!(
908 message.content,
909 vec![language_model::MessageContent::ToolResult(
910 LanguageModelToolResult {
911 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
912 tool_name: ToolRequiringPermission::NAME.into(),
913 is_error: false,
914 content: "Allowed".into(),
915 output: Some("Allowed".into())
916 }
917 )]
918 );
919
920 // Simulate a final tool call, ensuring we don't trigger authorization.
921 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
922 LanguageModelToolUse {
923 id: "tool_id_4".into(),
924 name: ToolRequiringPermission::NAME.into(),
925 raw_input: "{}".into(),
926 input: json!({}),
927 is_input_complete: true,
928 thought_signature: None,
929 },
930 ));
931 fake_model.end_last_completion_stream();
932 cx.run_until_parked();
933 let completion = fake_model.pending_completions().pop().unwrap();
934 let message = completion.messages.last().unwrap();
935 assert_eq!(
936 message.content,
937 vec![language_model::MessageContent::ToolResult(
938 LanguageModelToolResult {
939 tool_use_id: "tool_id_4".into(),
940 tool_name: ToolRequiringPermission::NAME.into(),
941 is_error: false,
942 content: "Allowed".into(),
943 output: Some("Allowed".into())
944 }
945 )]
946 );
947}
948
949#[gpui::test]
950async fn test_tool_hallucination(cx: &mut TestAppContext) {
951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
952 let fake_model = model.as_fake();
953
954 let mut events = thread
955 .update(cx, |thread, cx| {
956 thread.send(UserMessageId::new(), ["abc"], cx)
957 })
958 .unwrap();
959 cx.run_until_parked();
960 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
961 LanguageModelToolUse {
962 id: "tool_id_1".into(),
963 name: "nonexistent_tool".into(),
964 raw_input: "{}".into(),
965 input: json!({}),
966 is_input_complete: true,
967 thought_signature: None,
968 },
969 ));
970 fake_model.end_last_completion_stream();
971
972 let tool_call = expect_tool_call(&mut events).await;
973 assert_eq!(tool_call.title, "nonexistent_tool");
974 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
975 let update = expect_tool_call_update_fields(&mut events).await;
976 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
977}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
// A multi-word terminal command should offer three choices: always for the
// tool, always for the `cargo build` command pattern, and only this time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    );

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1059
// When the second token is a flag (`-la`), the command pattern should cover
// only the first token (`ls`).
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec![String::from("ls -la")]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1079
/// A single-word terminal command (`whoami`) still gets its own
/// per-command pattern choice alongside the tool-wide and one-off choices.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let command = String::from("whoami");
    let built =
        ToolPermissionContext::new(TerminalTool::NAME, vec![command]).build_permission_options();

    let dropdown = match built {
        PermissionOptions::Dropdown(dropdown) => dropdown,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(dropdown.len(), 3);

    let allow_labels: Vec<&str> = dropdown
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for terminal"));
    assert!(allow_labels.contains(&"Always for `whoami` commands"));
    assert!(allow_labels.contains(&"Only this time"));
}
1099
/// Editing `src/main.rs` should offer a tool-wide choice and a choice scoped
/// to the `src/` directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let target_path = String::from("src/main.rs");
    let built =
        ToolPermissionContext::new(EditFileTool::NAME, vec![target_path]).build_permission_options();

    let dropdown = match built {
        PermissionOptions::Dropdown(dropdown) => dropdown,
        _ => panic!("Expected dropdown permission options"),
    };

    let allow_labels: Vec<&str> = dropdown
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for edit file"));
    assert!(allow_labels.contains(&"Always for `src/`"));
}
1117
/// Fetching a URL should offer a tool-wide choice and a choice scoped to the
/// URL's domain (`docs.rs`).
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let url = String::from("https://docs.rs/gpui");
    let built = ToolPermissionContext::new(FetchTool::NAME, vec![url]).build_permission_options();

    let dropdown = match built {
        PermissionOptions::Dropdown(dropdown) => dropdown,
        _ => panic!("Expected dropdown permission options"),
    };

    let allow_labels: Vec<&str> = dropdown
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for fetch"));
    assert!(allow_labels.contains(&"Always for `docs.rs`"));
}
1135
/// For `./deploy.sh --production` only two choices are expected (tool-wide and
/// one-off); no label mentioning "commands" may be present.
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("./deploy.sh --production")],
    );

    let dropdown = match context.build_permission_options() {
        PermissionOptions::Dropdown(dropdown) => dropdown,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(dropdown.len(), 2);

    let allow_labels: Vec<&str> = dropdown
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for terminal"));
    assert!(allow_labels.contains(&"Only this time"));
    // No per-command pattern choice should have been generated.
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1157
/// Symlink-target authorization should produce a flat list with exactly two
/// one-shot options: `allow` (AllowOnce) and `deny` (RejectOnce).
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    let context =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()]);

    let flat_options = match context.build_permission_options() {
        PermissionOptions::Flat(flat_options) => flat_options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };
    assert_eq!(flat_options.len(), 2);

    // True when some option pairs the given id with the given kind.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        flat_options
            .iter()
            .any(|candidate| candidate.option_id.0.as_ref() == id && candidate.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1178
/// Checks the option ids generated for a terminal command: tool-wide,
/// one-off, and pattern-scoped ids must exist on both the allow and deny side.
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    );

    let dropdown = match context.build_permission_options() {
        PermissionOptions::Dropdown(dropdown) => dropdown,
        _ => panic!("Expected dropdown permission options"),
    };

    let mut allow_ids: Vec<String> = Vec::new();
    let mut deny_ids: Vec<String> = Vec::new();
    for entry in dropdown.iter() {
        allow_ids.push(entry.allow.option_id.0.to_string());
        deny_ids.push(entry.deny.option_id.0.to_string());
    }

    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    // Pattern ids embed the pattern after a newline, so only the prefix is checked.
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1218
1219#[gpui::test]
1220#[cfg_attr(not(feature = "e2e"), ignore)]
1221async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1222 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1223
1224 // Test concurrent tool calls with different delay times
1225 let events = thread
1226 .update(cx, |thread, cx| {
1227 thread.add_tool(DelayTool, None);
1228 thread.send(
1229 UserMessageId::new(),
1230 [
1231 "Call the delay tool twice in the same message.",
1232 "Once with 100ms. Once with 300ms.",
1233 "When both timers are complete, describe the outputs.",
1234 ],
1235 cx,
1236 )
1237 })
1238 .unwrap()
1239 .collect()
1240 .await;
1241
1242 let stop_reasons = stop_events(events);
1243 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1244
1245 thread.update(cx, |thread, _cx| {
1246 let last_message = thread.last_message().unwrap();
1247 let agent_message = last_message.as_agent_message().unwrap();
1248 let text = agent_message
1249 .content
1250 .iter()
1251 .filter_map(|content| {
1252 if let AgentMessageContent::Text(text) = content {
1253 Some(text.as_str())
1254 } else {
1255 None
1256 }
1257 })
1258 .collect::<String>();
1259
1260 assert!(text.contains("Ding"));
1261 });
1262}
1263
1264#[gpui::test]
1265async fn test_profiles(cx: &mut TestAppContext) {
1266 let ThreadTest {
1267 model, thread, fs, ..
1268 } = setup(cx, TestModel::Fake).await;
1269 let fake_model = model.as_fake();
1270
1271 thread.update(cx, |thread, _cx| {
1272 thread.add_tool(DelayTool, None);
1273 thread.add_tool(EchoTool, None);
1274 thread.add_tool(InfiniteTool, None);
1275 });
1276
1277 // Override profiles and wait for settings to be loaded.
1278 fs.insert_file(
1279 paths::settings_file(),
1280 json!({
1281 "agent": {
1282 "profiles": {
1283 "test-1": {
1284 "name": "Test Profile 1",
1285 "tools": {
1286 EchoTool::NAME: true,
1287 DelayTool::NAME: true,
1288 }
1289 },
1290 "test-2": {
1291 "name": "Test Profile 2",
1292 "tools": {
1293 InfiniteTool::NAME: true,
1294 }
1295 }
1296 }
1297 }
1298 })
1299 .to_string()
1300 .into_bytes(),
1301 )
1302 .await;
1303 cx.run_until_parked();
1304
1305 // Test that test-1 profile (default) has echo and delay tools
1306 thread
1307 .update(cx, |thread, cx| {
1308 thread.set_profile(AgentProfileId("test-1".into()), cx);
1309 thread.send(UserMessageId::new(), ["test"], cx)
1310 })
1311 .unwrap();
1312 cx.run_until_parked();
1313
1314 let mut pending_completions = fake_model.pending_completions();
1315 assert_eq!(pending_completions.len(), 1);
1316 let completion = pending_completions.pop().unwrap();
1317 let tool_names: Vec<String> = completion
1318 .tools
1319 .iter()
1320 .map(|tool| tool.name.clone())
1321 .collect();
1322 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1323 fake_model.end_last_completion_stream();
1324
1325 // Switch to test-2 profile, and verify that it has only the infinite tool.
1326 thread
1327 .update(cx, |thread, cx| {
1328 thread.set_profile(AgentProfileId("test-2".into()), cx);
1329 thread.send(UserMessageId::new(), ["test2"], cx)
1330 })
1331 .unwrap();
1332 cx.run_until_parked();
1333 let mut pending_completions = fake_model.pending_completions();
1334 assert_eq!(pending_completions.len(), 1);
1335 let completion = pending_completions.pop().unwrap();
1336 let tool_names: Vec<String> = completion
1337 .tools
1338 .iter()
1339 .map(|tool| tool.name.clone())
1340 .collect();
1341 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1342}
1343
1344#[gpui::test]
1345async fn test_mcp_tools(cx: &mut TestAppContext) {
1346 let ThreadTest {
1347 model,
1348 thread,
1349 context_server_store,
1350 fs,
1351 ..
1352 } = setup(cx, TestModel::Fake).await;
1353 let fake_model = model.as_fake();
1354
1355 // Override profiles and wait for settings to be loaded.
1356 fs.insert_file(
1357 paths::settings_file(),
1358 json!({
1359 "agent": {
1360 "tool_permissions": { "default": "allow" },
1361 "profiles": {
1362 "test": {
1363 "name": "Test Profile",
1364 "enable_all_context_servers": true,
1365 "tools": {
1366 EchoTool::NAME: true,
1367 }
1368 },
1369 }
1370 }
1371 })
1372 .to_string()
1373 .into_bytes(),
1374 )
1375 .await;
1376 cx.run_until_parked();
1377 thread.update(cx, |thread, cx| {
1378 thread.set_profile(AgentProfileId("test".into()), cx)
1379 });
1380
1381 let mut mcp_tool_calls = setup_context_server(
1382 "test_server",
1383 vec![context_server::types::Tool {
1384 name: "echo".into(),
1385 description: None,
1386 input_schema: serde_json::to_value(EchoTool::input_schema(
1387 LanguageModelToolSchemaFormat::JsonSchema,
1388 ))
1389 .unwrap(),
1390 output_schema: None,
1391 annotations: None,
1392 }],
1393 &context_server_store,
1394 cx,
1395 );
1396
1397 let events = thread.update(cx, |thread, cx| {
1398 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1399 });
1400 cx.run_until_parked();
1401
1402 // Simulate the model calling the MCP tool.
1403 let completion = fake_model.pending_completions().pop().unwrap();
1404 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1405 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1406 LanguageModelToolUse {
1407 id: "tool_1".into(),
1408 name: "echo".into(),
1409 raw_input: json!({"text": "test"}).to_string(),
1410 input: json!({"text": "test"}),
1411 is_input_complete: true,
1412 thought_signature: None,
1413 },
1414 ));
1415 fake_model.end_last_completion_stream();
1416 cx.run_until_parked();
1417
1418 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1419 assert_eq!(tool_call_params.name, "echo");
1420 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1421 tool_call_response
1422 .send(context_server::types::CallToolResponse {
1423 content: vec![context_server::types::ToolResponseContent::Text {
1424 text: "test".into(),
1425 }],
1426 is_error: None,
1427 meta: None,
1428 structured_content: None,
1429 })
1430 .unwrap();
1431 cx.run_until_parked();
1432
1433 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1434 fake_model.send_last_completion_stream_text_chunk("Done!");
1435 fake_model.end_last_completion_stream();
1436 events.collect::<Vec<_>>().await;
1437
1438 // Send again after adding the echo tool, ensuring the name collision is resolved.
1439 let events = thread.update(cx, |thread, cx| {
1440 thread.add_tool(EchoTool, None);
1441 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1442 });
1443 cx.run_until_parked();
1444 let completion = fake_model.pending_completions().pop().unwrap();
1445 assert_eq!(
1446 tool_names_for_completion(&completion),
1447 vec!["echo", "test_server_echo"]
1448 );
1449 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1450 LanguageModelToolUse {
1451 id: "tool_2".into(),
1452 name: "test_server_echo".into(),
1453 raw_input: json!({"text": "mcp"}).to_string(),
1454 input: json!({"text": "mcp"}),
1455 is_input_complete: true,
1456 thought_signature: None,
1457 },
1458 ));
1459 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1460 LanguageModelToolUse {
1461 id: "tool_3".into(),
1462 name: "echo".into(),
1463 raw_input: json!({"text": "native"}).to_string(),
1464 input: json!({"text": "native"}),
1465 is_input_complete: true,
1466 thought_signature: None,
1467 },
1468 ));
1469 fake_model.end_last_completion_stream();
1470 cx.run_until_parked();
1471
1472 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1473 assert_eq!(tool_call_params.name, "echo");
1474 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1475 tool_call_response
1476 .send(context_server::types::CallToolResponse {
1477 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1478 is_error: None,
1479 meta: None,
1480 structured_content: None,
1481 })
1482 .unwrap();
1483 cx.run_until_parked();
1484
1485 // Ensure the tool results were inserted with the correct names.
1486 let completion = fake_model.pending_completions().pop().unwrap();
1487 assert_eq!(
1488 completion.messages.last().unwrap().content,
1489 vec![
1490 MessageContent::ToolResult(LanguageModelToolResult {
1491 tool_use_id: "tool_3".into(),
1492 tool_name: "echo".into(),
1493 is_error: false,
1494 content: "native".into(),
1495 output: Some("native".into()),
1496 },),
1497 MessageContent::ToolResult(LanguageModelToolResult {
1498 tool_use_id: "tool_2".into(),
1499 tool_name: "test_server_echo".into(),
1500 is_error: false,
1501 content: "mcp".into(),
1502 output: Some("mcp".into()),
1503 },),
1504 ]
1505 );
1506 fake_model.end_last_completion_stream();
1507 events.collect::<Vec<_>>().await;
1508}
1509
/// Verifies that a stored MCP tool result is still emitted when the thread is
/// replayed after its context server has been stopped: the replay must yield
/// the tool call plus an update carrying the saved `raw_output` and a
/// `Completed` status.
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // Drain the event stream so the turn is fully finished before inspecting the thread.
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown output.
    // The tool result is in the first assistant message (not the last one, which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    context_server_store.update(cx, |store, cx| {
        // The result of stopping the server is deliberately ignored.
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    // Scan the replayed events for the original tool call ("tool_1") and for
    // an update on it that carries a raw output.
    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1691
/// Checks the exact list of tool names sent to the model when native tools and
/// four context servers expose tools whose names collide with each other or
/// are close to / beyond `MAX_TOOL_NAME_LENGTH`.
///
/// NOTE(review): the expected list suggests colliding names get a server-derived
/// prefix and over-long names are cut to `MAX_TOOL_NAME_LENGTH`; the precise
/// rules live in the implementation, not in this test — confirm there.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool, None);
        thread.add_tool(DelayTool, None);
        thread.add_tool(WordListTool, None);
        thread.add_tool(ToolRequiringPermission, None);
        thread.add_tool(InfiniteTool, None);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Same name as a tool on server "zzz", two characters under the limit.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Same name as a tool on server "zzz", one character under the limit.
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // One character over the limit.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Inspect the tool list of the request that was just issued.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1875
/// End-to-end check (ignored unless the `e2e` feature is enabled) that
/// cancelling a turn while a tool is still running closes the event stream
/// with a `Cancelled` stop reason, and that the thread accepts a new message
/// afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool, None);
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool-call titles are expected in this order; each one is removed as it is seen.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // A `Completed` status update for the echo tool call recorded above.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        // Stop once both tool calls were seen and the echo call has completed.
        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1961
/// Verifies that cancelling a thread while the terminal tool is running kills
/// the terminal, ends the turn with a `Cancelled` stop reason, and records a
/// tool result containing the terminal's output plus a note that the user
/// stopped the command.
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Fake environment whose terminal never exits on its own.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep a handle to the fake terminal so its state can be checked after cancellation.
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // First tool use recorded in the last agent message.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2058
2059#[gpui::test]
2060async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2061 // This test verifies that tools which properly handle cancellation via
2062 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2063 // to cancellation and report that they were cancelled.
2064 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2065 always_allow_tools(cx);
2066 let fake_model = model.as_fake();
2067
2068 let (tool, was_cancelled) = CancellationAwareTool::new();
2069
2070 let mut events = thread
2071 .update(cx, |thread, cx| {
2072 thread.add_tool(tool, None);
2073 thread.send(
2074 UserMessageId::new(),
2075 ["call the cancellation aware tool"],
2076 cx,
2077 )
2078 })
2079 .unwrap();
2080
2081 cx.run_until_parked();
2082
2083 // Simulate the model calling the cancellation-aware tool
2084 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2085 LanguageModelToolUse {
2086 id: "cancellation_aware_1".into(),
2087 name: "cancellation_aware".into(),
2088 raw_input: r#"{}"#.into(),
2089 input: json!({}),
2090 is_input_complete: true,
2091 thought_signature: None,
2092 },
2093 ));
2094 fake_model.end_last_completion_stream();
2095
2096 cx.run_until_parked();
2097
2098 // Wait for the tool call to be reported
2099 let mut tool_started = false;
2100 let deadline = cx.executor().num_cpus() * 100;
2101 for _ in 0..deadline {
2102 cx.run_until_parked();
2103
2104 while let Some(Some(event)) = events.next().now_or_never() {
2105 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2106 if tool_call.title == "Cancellation Aware Tool" {
2107 tool_started = true;
2108 break;
2109 }
2110 }
2111 }
2112
2113 if tool_started {
2114 break;
2115 }
2116
2117 cx.background_executor
2118 .timer(Duration::from_millis(10))
2119 .await;
2120 }
2121 assert!(tool_started, "expected cancellation aware tool to start");
2122
2123 // Cancel the thread and wait for it to complete
2124 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2125
2126 // The cancel task should complete promptly because the tool handles cancellation
2127 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2128 futures::select! {
2129 _ = cancel_task.fuse() => {}
2130 _ = timeout.fuse() => {
2131 panic!("cancel task timed out - tool did not respond to cancellation");
2132 }
2133 }
2134
2135 // Verify the tool detected cancellation via its flag
2136 assert!(
2137 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2138 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2139 );
2140
2141 // Collect remaining events
2142 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2143
2144 // Verify we got a cancellation stop event
2145 assert_eq!(
2146 stop_events(remaining_events),
2147 vec![acp::StopReason::Cancelled],
2148 );
2149
2150 // Verify we can send a new message after cancellation
2151 verify_thread_recovery(&thread, &fake_model, cx).await;
2152}
2153
2154/// Helper to verify thread can recover after cancellation by sending a simple message.
2155async fn verify_thread_recovery(
2156 thread: &Entity<Thread>,
2157 fake_model: &FakeLanguageModel,
2158 cx: &mut TestAppContext,
2159) {
2160 let events = thread
2161 .update(cx, |thread, cx| {
2162 thread.send(
2163 UserMessageId::new(),
2164 ["Testing: reply with 'Hello' then stop."],
2165 cx,
2166 )
2167 })
2168 .unwrap();
2169 cx.run_until_parked();
2170 fake_model.send_last_completion_stream_text_chunk("Hello");
2171 fake_model
2172 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2173 fake_model.end_last_completion_stream();
2174
2175 let events = events.collect::<Vec<_>>().await;
2176 thread.update(cx, |thread, _cx| {
2177 let message = thread.last_message().unwrap();
2178 let agent_message = message.as_agent_message().unwrap();
2179 assert_eq!(
2180 agent_message.content,
2181 vec![AgentMessageContent::Text("Hello".to_string())]
2182 );
2183 });
2184 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2185}
2186
2187/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2188async fn wait_for_terminal_tool_started(
2189 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2190 cx: &mut TestAppContext,
2191) {
2192 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2193 for _ in 0..deadline {
2194 cx.run_until_parked();
2195
2196 while let Some(Some(event)) = events.next().now_or_never() {
2197 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2198 update,
2199 ))) = &event
2200 {
2201 if update.fields.content.as_ref().is_some_and(|content| {
2202 content
2203 .iter()
2204 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2205 }) {
2206 return;
2207 }
2208 }
2209 }
2210
2211 cx.background_executor
2212 .timer(Duration::from_millis(10))
2213 .await;
2214 }
2215 panic!("terminal tool did not start within the expected time");
2216}
2217
2218/// Collects events until a Stop event is received, driving the executor to completion.
2219async fn collect_events_until_stop(
2220 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2221 cx: &mut TestAppContext,
2222) -> Vec<Result<ThreadEvent>> {
2223 let mut collected = Vec::new();
2224 let deadline = cx.executor().num_cpus() * 200;
2225
2226 for _ in 0..deadline {
2227 cx.executor().advance_clock(Duration::from_millis(10));
2228 cx.run_until_parked();
2229
2230 while let Some(Some(event)) = events.next().now_or_never() {
2231 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2232 collected.push(event);
2233 if is_stop {
2234 return collected;
2235 }
2236 }
2237 }
2238 panic!(
2239 "did not receive Stop event within the expected time; collected {} events",
2240 collected.len()
2241 );
2242}
2243
2244#[gpui::test]
2245async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2246 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2247 always_allow_tools(cx);
2248 let fake_model = model.as_fake();
2249
2250 let environment = Rc::new(cx.update(|cx| {
2251 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2252 }));
2253 let handle = environment.terminal_handle.clone().unwrap();
2254
2255 let message_id = UserMessageId::new();
2256 let mut events = thread
2257 .update(cx, |thread, cx| {
2258 thread.add_tool(
2259 crate::TerminalTool::new(thread.project().clone(), environment),
2260 None,
2261 );
2262 thread.send(message_id.clone(), ["run a command"], cx)
2263 })
2264 .unwrap();
2265
2266 cx.run_until_parked();
2267
2268 // Simulate the model calling the terminal tool
2269 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2270 LanguageModelToolUse {
2271 id: "terminal_tool_1".into(),
2272 name: TerminalTool::NAME.into(),
2273 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2274 input: json!({"command": "sleep 1000", "cd": "."}),
2275 is_input_complete: true,
2276 thought_signature: None,
2277 },
2278 ));
2279 fake_model.end_last_completion_stream();
2280
2281 // Wait for the terminal tool to start running
2282 wait_for_terminal_tool_started(&mut events, cx).await;
2283
2284 // Truncate the thread while the terminal is running
2285 thread
2286 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2287 .unwrap();
2288
2289 // Drive the executor to let cancellation complete
2290 let _ = collect_events_until_stop(&mut events, cx).await;
2291
2292 // Verify the terminal was killed
2293 assert!(
2294 handle.was_killed(),
2295 "expected terminal handle to be killed on truncate"
2296 );
2297
2298 // Verify the thread is empty after truncation
2299 thread.update(cx, |thread, _cx| {
2300 assert_eq!(
2301 thread.to_markdown(),
2302 "",
2303 "expected thread to be empty after truncating the only message"
2304 );
2305 });
2306
2307 // Verify we can send a new message after truncation
2308 verify_thread_recovery(&thread, &fake_model, cx).await;
2309}
2310
2311#[gpui::test]
2312async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2313 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2314 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2315 always_allow_tools(cx);
2316 let fake_model = model.as_fake();
2317
2318 let environment = Rc::new(MultiTerminalEnvironment::new());
2319
2320 let mut events = thread
2321 .update(cx, |thread, cx| {
2322 thread.add_tool(
2323 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2324 None,
2325 );
2326 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2327 })
2328 .unwrap();
2329
2330 cx.run_until_parked();
2331
2332 // Simulate the model calling two terminal tools
2333 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2334 LanguageModelToolUse {
2335 id: "terminal_tool_1".into(),
2336 name: TerminalTool::NAME.into(),
2337 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2338 input: json!({"command": "sleep 1000", "cd": "."}),
2339 is_input_complete: true,
2340 thought_signature: None,
2341 },
2342 ));
2343 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2344 LanguageModelToolUse {
2345 id: "terminal_tool_2".into(),
2346 name: TerminalTool::NAME.into(),
2347 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2348 input: json!({"command": "sleep 2000", "cd": "."}),
2349 is_input_complete: true,
2350 thought_signature: None,
2351 },
2352 ));
2353 fake_model.end_last_completion_stream();
2354
2355 // Wait for both terminal tools to start by counting terminal content updates
2356 let mut terminals_started = 0;
2357 let deadline = cx.executor().num_cpus() * 100;
2358 for _ in 0..deadline {
2359 cx.run_until_parked();
2360
2361 while let Some(Some(event)) = events.next().now_or_never() {
2362 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2363 update,
2364 ))) = &event
2365 {
2366 if update.fields.content.as_ref().is_some_and(|content| {
2367 content
2368 .iter()
2369 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2370 }) {
2371 terminals_started += 1;
2372 if terminals_started >= 2 {
2373 break;
2374 }
2375 }
2376 }
2377 }
2378 if terminals_started >= 2 {
2379 break;
2380 }
2381
2382 cx.background_executor
2383 .timer(Duration::from_millis(10))
2384 .await;
2385 }
2386 assert!(
2387 terminals_started >= 2,
2388 "expected 2 terminal tools to start, got {terminals_started}"
2389 );
2390
2391 // Cancel the thread while both terminals are running
2392 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2393
2394 // Collect remaining events
2395 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2396
2397 // Verify both terminal handles were killed
2398 let handles = environment.handles();
2399 assert_eq!(
2400 handles.len(),
2401 2,
2402 "expected 2 terminal handles to be created"
2403 );
2404 assert!(
2405 handles[0].was_killed(),
2406 "expected first terminal handle to be killed on cancellation"
2407 );
2408 assert!(
2409 handles[1].was_killed(),
2410 "expected second terminal handle to be killed on cancellation"
2411 );
2412
2413 // Verify we got a cancellation stop event
2414 assert_eq!(
2415 stop_events(remaining_events),
2416 vec![acp::StopReason::Cancelled],
2417 );
2418}
2419
2420#[gpui::test]
2421async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2422 // Tests that clicking the stop button on the terminal card (as opposed to the main
2423 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2424 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2425 always_allow_tools(cx);
2426 let fake_model = model.as_fake();
2427
2428 let environment = Rc::new(cx.update(|cx| {
2429 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2430 }));
2431 let handle = environment.terminal_handle.clone().unwrap();
2432
2433 let mut events = thread
2434 .update(cx, |thread, cx| {
2435 thread.add_tool(
2436 crate::TerminalTool::new(thread.project().clone(), environment),
2437 None,
2438 );
2439 thread.send(UserMessageId::new(), ["run a command"], cx)
2440 })
2441 .unwrap();
2442
2443 cx.run_until_parked();
2444
2445 // Simulate the model calling the terminal tool
2446 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2447 LanguageModelToolUse {
2448 id: "terminal_tool_1".into(),
2449 name: TerminalTool::NAME.into(),
2450 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2451 input: json!({"command": "sleep 1000", "cd": "."}),
2452 is_input_complete: true,
2453 thought_signature: None,
2454 },
2455 ));
2456 fake_model.end_last_completion_stream();
2457
2458 // Wait for the terminal tool to start running
2459 wait_for_terminal_tool_started(&mut events, cx).await;
2460
2461 // Simulate user clicking stop on the terminal card itself.
2462 // This sets the flag and signals exit (simulating what the real UI would do).
2463 handle.set_stopped_by_user(true);
2464 handle.killed.store(true, Ordering::SeqCst);
2465 handle.signal_exit();
2466
2467 // Wait for the tool to complete
2468 cx.run_until_parked();
2469
2470 // The thread continues after tool completion - simulate the model ending its turn
2471 fake_model
2472 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2473 fake_model.end_last_completion_stream();
2474
2475 // Collect remaining events
2476 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2477
2478 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2479 assert_eq!(
2480 stop_events(remaining_events),
2481 vec![acp::StopReason::EndTurn],
2482 );
2483
2484 // Verify the tool result indicates user stopped
2485 thread.update(cx, |thread, _cx| {
2486 let message = thread.last_message().unwrap();
2487 let agent_message = message.as_agent_message().unwrap();
2488
2489 let tool_use = agent_message
2490 .content
2491 .iter()
2492 .find_map(|content| match content {
2493 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2494 _ => None,
2495 })
2496 .expect("expected tool use in agent message");
2497
2498 let tool_result = agent_message
2499 .tool_results
2500 .get(&tool_use.id)
2501 .expect("expected tool result");
2502
2503 let result_text = match &tool_result.content {
2504 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2505 _ => panic!("expected text content in tool result"),
2506 };
2507
2508 assert!(
2509 result_text.contains("The user stopped this command"),
2510 "expected tool result to indicate user stopped, got: {result_text}"
2511 );
2512 });
2513}
2514
2515#[gpui::test]
2516async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2517 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2519 always_allow_tools(cx);
2520 let fake_model = model.as_fake();
2521
2522 let environment = Rc::new(cx.update(|cx| {
2523 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2524 }));
2525 let handle = environment.terminal_handle.clone().unwrap();
2526
2527 let mut events = thread
2528 .update(cx, |thread, cx| {
2529 thread.add_tool(
2530 crate::TerminalTool::new(thread.project().clone(), environment),
2531 None,
2532 );
2533 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2534 })
2535 .unwrap();
2536
2537 cx.run_until_parked();
2538
2539 // Simulate the model calling the terminal tool with a short timeout
2540 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2541 LanguageModelToolUse {
2542 id: "terminal_tool_1".into(),
2543 name: TerminalTool::NAME.into(),
2544 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2545 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2546 is_input_complete: true,
2547 thought_signature: None,
2548 },
2549 ));
2550 fake_model.end_last_completion_stream();
2551
2552 // Wait for the terminal tool to start running
2553 wait_for_terminal_tool_started(&mut events, cx).await;
2554
2555 // Advance clock past the timeout
2556 cx.executor().advance_clock(Duration::from_millis(200));
2557 cx.run_until_parked();
2558
2559 // The thread continues after tool completion - simulate the model ending its turn
2560 fake_model
2561 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2562 fake_model.end_last_completion_stream();
2563
2564 // Collect remaining events
2565 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2566
2567 // Verify the terminal was killed due to timeout
2568 assert!(
2569 handle.was_killed(),
2570 "expected terminal handle to be killed on timeout"
2571 );
2572
2573 // Verify we got an EndTurn (the tool completed, just with timeout)
2574 assert_eq!(
2575 stop_events(remaining_events),
2576 vec![acp::StopReason::EndTurn],
2577 );
2578
2579 // Verify the tool result indicates timeout, not user stopped
2580 thread.update(cx, |thread, _cx| {
2581 let message = thread.last_message().unwrap();
2582 let agent_message = message.as_agent_message().unwrap();
2583
2584 let tool_use = agent_message
2585 .content
2586 .iter()
2587 .find_map(|content| match content {
2588 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2589 _ => None,
2590 })
2591 .expect("expected tool use in agent message");
2592
2593 let tool_result = agent_message
2594 .tool_results
2595 .get(&tool_use.id)
2596 .expect("expected tool result");
2597
2598 let result_text = match &tool_result.content {
2599 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2600 _ => panic!("expected text content in tool result"),
2601 };
2602
2603 assert!(
2604 result_text.contains("timed out"),
2605 "expected tool result to indicate timeout, got: {result_text}"
2606 );
2607 assert!(
2608 !result_text.contains("The user stopped"),
2609 "tool result should not mention user stopped when it timed out, got: {result_text}"
2610 );
2611 });
2612}
2613
2614#[gpui::test]
2615async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2616 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2617 let fake_model = model.as_fake();
2618
2619 let events_1 = thread
2620 .update(cx, |thread, cx| {
2621 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2622 })
2623 .unwrap();
2624 cx.run_until_parked();
2625 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2626 cx.run_until_parked();
2627
2628 let events_2 = thread
2629 .update(cx, |thread, cx| {
2630 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2631 })
2632 .unwrap();
2633 cx.run_until_parked();
2634 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2635 fake_model
2636 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2637 fake_model.end_last_completion_stream();
2638
2639 let events_1 = events_1.collect::<Vec<_>>().await;
2640 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2641 let events_2 = events_2.collect::<Vec<_>>().await;
2642 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2643}
2644
2645#[gpui::test]
2646async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2647 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2648 let fake_model = model.as_fake();
2649
2650 let events_1 = thread
2651 .update(cx, |thread, cx| {
2652 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2653 })
2654 .unwrap();
2655 cx.run_until_parked();
2656 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2657 fake_model
2658 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2659 fake_model.end_last_completion_stream();
2660 let events_1 = events_1.collect::<Vec<_>>().await;
2661
2662 let events_2 = thread
2663 .update(cx, |thread, cx| {
2664 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2665 })
2666 .unwrap();
2667 cx.run_until_parked();
2668 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2669 fake_model
2670 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2671 fake_model.end_last_completion_stream();
2672 let events_2 = events_2.collect::<Vec<_>>().await;
2673
2674 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2675 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2676}
2677
2678#[gpui::test]
2679async fn test_refusal(cx: &mut TestAppContext) {
2680 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2681 let fake_model = model.as_fake();
2682
2683 let events = thread
2684 .update(cx, |thread, cx| {
2685 thread.send(UserMessageId::new(), ["Hello"], cx)
2686 })
2687 .unwrap();
2688 cx.run_until_parked();
2689 thread.read_with(cx, |thread, _| {
2690 assert_eq!(
2691 thread.to_markdown(),
2692 indoc! {"
2693 ## User
2694
2695 Hello
2696 "}
2697 );
2698 });
2699
2700 fake_model.send_last_completion_stream_text_chunk("Hey!");
2701 cx.run_until_parked();
2702 thread.read_with(cx, |thread, _| {
2703 assert_eq!(
2704 thread.to_markdown(),
2705 indoc! {"
2706 ## User
2707
2708 Hello
2709
2710 ## Assistant
2711
2712 Hey!
2713 "}
2714 );
2715 });
2716
2717 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2718 fake_model
2719 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2720 let events = events.collect::<Vec<_>>().await;
2721 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2722 thread.read_with(cx, |thread, _| {
2723 assert_eq!(thread.to_markdown(), "");
2724 });
2725}
2726
2727#[gpui::test]
2728async fn test_truncate_first_message(cx: &mut TestAppContext) {
2729 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2730 let fake_model = model.as_fake();
2731
2732 let message_id = UserMessageId::new();
2733 thread
2734 .update(cx, |thread, cx| {
2735 thread.send(message_id.clone(), ["Hello"], cx)
2736 })
2737 .unwrap();
2738 cx.run_until_parked();
2739 thread.read_with(cx, |thread, _| {
2740 assert_eq!(
2741 thread.to_markdown(),
2742 indoc! {"
2743 ## User
2744
2745 Hello
2746 "}
2747 );
2748 assert_eq!(thread.latest_token_usage(), None);
2749 });
2750
2751 fake_model.send_last_completion_stream_text_chunk("Hey!");
2752 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2753 language_model::TokenUsage {
2754 input_tokens: 32_000,
2755 output_tokens: 16_000,
2756 cache_creation_input_tokens: 0,
2757 cache_read_input_tokens: 0,
2758 },
2759 ));
2760 cx.run_until_parked();
2761 thread.read_with(cx, |thread, _| {
2762 assert_eq!(
2763 thread.to_markdown(),
2764 indoc! {"
2765 ## User
2766
2767 Hello
2768
2769 ## Assistant
2770
2771 Hey!
2772 "}
2773 );
2774 assert_eq!(
2775 thread.latest_token_usage(),
2776 Some(acp_thread::TokenUsage {
2777 used_tokens: 32_000 + 16_000,
2778 max_tokens: 1_000_000,
2779 max_output_tokens: None,
2780 input_tokens: 32_000,
2781 output_tokens: 16_000,
2782 })
2783 );
2784 });
2785
2786 thread
2787 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2788 .unwrap();
2789 cx.run_until_parked();
2790 thread.read_with(cx, |thread, _| {
2791 assert_eq!(thread.to_markdown(), "");
2792 assert_eq!(thread.latest_token_usage(), None);
2793 });
2794
2795 // Ensure we can still send a new message after truncation.
2796 thread
2797 .update(cx, |thread, cx| {
2798 thread.send(UserMessageId::new(), ["Hi"], cx)
2799 })
2800 .unwrap();
2801 thread.update(cx, |thread, _cx| {
2802 assert_eq!(
2803 thread.to_markdown(),
2804 indoc! {"
2805 ## User
2806
2807 Hi
2808 "}
2809 );
2810 });
2811 cx.run_until_parked();
2812 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2813 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2814 language_model::TokenUsage {
2815 input_tokens: 40_000,
2816 output_tokens: 20_000,
2817 cache_creation_input_tokens: 0,
2818 cache_read_input_tokens: 0,
2819 },
2820 ));
2821 cx.run_until_parked();
2822 thread.read_with(cx, |thread, _| {
2823 assert_eq!(
2824 thread.to_markdown(),
2825 indoc! {"
2826 ## User
2827
2828 Hi
2829
2830 ## Assistant
2831
2832 Ahoy!
2833 "}
2834 );
2835
2836 assert_eq!(
2837 thread.latest_token_usage(),
2838 Some(acp_thread::TokenUsage {
2839 used_tokens: 40_000 + 20_000,
2840 max_tokens: 1_000_000,
2841 max_output_tokens: None,
2842 input_tokens: 40_000,
2843 output_tokens: 20_000,
2844 })
2845 );
2846 });
2847}
2848
2849#[gpui::test]
2850async fn test_truncate_second_message(cx: &mut TestAppContext) {
2851 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2852 let fake_model = model.as_fake();
2853
2854 thread
2855 .update(cx, |thread, cx| {
2856 thread.send(UserMessageId::new(), ["Message 1"], cx)
2857 })
2858 .unwrap();
2859 cx.run_until_parked();
2860 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2861 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2862 language_model::TokenUsage {
2863 input_tokens: 32_000,
2864 output_tokens: 16_000,
2865 cache_creation_input_tokens: 0,
2866 cache_read_input_tokens: 0,
2867 },
2868 ));
2869 fake_model.end_last_completion_stream();
2870 cx.run_until_parked();
2871
2872 let assert_first_message_state = |cx: &mut TestAppContext| {
2873 thread.clone().read_with(cx, |thread, _| {
2874 assert_eq!(
2875 thread.to_markdown(),
2876 indoc! {"
2877 ## User
2878
2879 Message 1
2880
2881 ## Assistant
2882
2883 Message 1 response
2884 "}
2885 );
2886
2887 assert_eq!(
2888 thread.latest_token_usage(),
2889 Some(acp_thread::TokenUsage {
2890 used_tokens: 32_000 + 16_000,
2891 max_tokens: 1_000_000,
2892 max_output_tokens: None,
2893 input_tokens: 32_000,
2894 output_tokens: 16_000,
2895 })
2896 );
2897 });
2898 };
2899
2900 assert_first_message_state(cx);
2901
2902 let second_message_id = UserMessageId::new();
2903 thread
2904 .update(cx, |thread, cx| {
2905 thread.send(second_message_id.clone(), ["Message 2"], cx)
2906 })
2907 .unwrap();
2908 cx.run_until_parked();
2909
2910 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2911 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2912 language_model::TokenUsage {
2913 input_tokens: 40_000,
2914 output_tokens: 20_000,
2915 cache_creation_input_tokens: 0,
2916 cache_read_input_tokens: 0,
2917 },
2918 ));
2919 fake_model.end_last_completion_stream();
2920 cx.run_until_parked();
2921
2922 thread.read_with(cx, |thread, _| {
2923 assert_eq!(
2924 thread.to_markdown(),
2925 indoc! {"
2926 ## User
2927
2928 Message 1
2929
2930 ## Assistant
2931
2932 Message 1 response
2933
2934 ## User
2935
2936 Message 2
2937
2938 ## Assistant
2939
2940 Message 2 response
2941 "}
2942 );
2943
2944 assert_eq!(
2945 thread.latest_token_usage(),
2946 Some(acp_thread::TokenUsage {
2947 used_tokens: 40_000 + 20_000,
2948 max_tokens: 1_000_000,
2949 max_output_tokens: None,
2950 input_tokens: 40_000,
2951 output_tokens: 20_000,
2952 })
2953 );
2954 });
2955
2956 thread
2957 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2958 .unwrap();
2959 cx.run_until_parked();
2960
2961 assert_first_message_state(cx);
2962}
2963
2964#[gpui::test]
2965async fn test_title_generation(cx: &mut TestAppContext) {
2966 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2967 let fake_model = model.as_fake();
2968
2969 let summary_model = Arc::new(FakeLanguageModel::default());
2970 thread.update(cx, |thread, cx| {
2971 thread.set_summarization_model(Some(summary_model.clone()), cx)
2972 });
2973
2974 let send = thread
2975 .update(cx, |thread, cx| {
2976 thread.send(UserMessageId::new(), ["Hello"], cx)
2977 })
2978 .unwrap();
2979 cx.run_until_parked();
2980
2981 fake_model.send_last_completion_stream_text_chunk("Hey!");
2982 fake_model.end_last_completion_stream();
2983 cx.run_until_parked();
2984 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2985
2986 // Ensure the summary model has been invoked to generate a title.
2987 summary_model.send_last_completion_stream_text_chunk("Hello ");
2988 summary_model.send_last_completion_stream_text_chunk("world\nG");
2989 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2990 summary_model.end_last_completion_stream();
2991 send.collect::<Vec<_>>().await;
2992 cx.run_until_parked();
2993 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2994
2995 // Send another message, ensuring no title is generated this time.
2996 let send = thread
2997 .update(cx, |thread, cx| {
2998 thread.send(UserMessageId::new(), ["Hello again"], cx)
2999 })
3000 .unwrap();
3001 cx.run_until_parked();
3002 fake_model.send_last_completion_stream_text_chunk("Hey again!");
3003 fake_model.end_last_completion_stream();
3004 cx.run_until_parked();
3005 assert_eq!(summary_model.pending_completions(), Vec::new());
3006 send.collect::<Vec<_>>().await;
3007 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
3008}
3009
3010#[gpui::test]
3011async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
3012 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3013 let fake_model = model.as_fake();
3014
3015 let _events = thread
3016 .update(cx, |thread, cx| {
3017 thread.add_tool(ToolRequiringPermission, None);
3018 thread.add_tool(EchoTool, None);
3019 thread.send(UserMessageId::new(), ["Hey!"], cx)
3020 })
3021 .unwrap();
3022 cx.run_until_parked();
3023
3024 let permission_tool_use = LanguageModelToolUse {
3025 id: "tool_id_1".into(),
3026 name: ToolRequiringPermission::NAME.into(),
3027 raw_input: "{}".into(),
3028 input: json!({}),
3029 is_input_complete: true,
3030 thought_signature: None,
3031 };
3032 let echo_tool_use = LanguageModelToolUse {
3033 id: "tool_id_2".into(),
3034 name: EchoTool::NAME.into(),
3035 raw_input: json!({"text": "test"}).to_string(),
3036 input: json!({"text": "test"}),
3037 is_input_complete: true,
3038 thought_signature: None,
3039 };
3040 fake_model.send_last_completion_stream_text_chunk("Hi!");
3041 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3042 permission_tool_use,
3043 ));
3044 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3045 echo_tool_use.clone(),
3046 ));
3047 fake_model.end_last_completion_stream();
3048 cx.run_until_parked();
3049
3050 // Ensure pending tools are skipped when building a request.
3051 let request = thread
3052 .read_with(cx, |thread, cx| {
3053 thread.build_completion_request(CompletionIntent::EditFile, cx)
3054 })
3055 .unwrap();
3056 assert_eq!(
3057 request.messages[1..],
3058 vec![
3059 LanguageModelRequestMessage {
3060 role: Role::User,
3061 content: vec!["Hey!".into()],
3062 cache: true,
3063 reasoning_details: None,
3064 },
3065 LanguageModelRequestMessage {
3066 role: Role::Assistant,
3067 content: vec![
3068 MessageContent::Text("Hi!".into()),
3069 MessageContent::ToolUse(echo_tool_use.clone())
3070 ],
3071 cache: false,
3072 reasoning_details: None,
3073 },
3074 LanguageModelRequestMessage {
3075 role: Role::User,
3076 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3077 tool_use_id: echo_tool_use.id.clone(),
3078 tool_name: echo_tool_use.name,
3079 is_error: false,
3080 content: "test".into(),
3081 output: Some("test".into())
3082 })],
3083 cache: false,
3084 reasoning_details: None,
3085 },
3086 ],
3087 );
3088}
3089
/// End-to-end exercise of `NativeAgentConnection`.
///
/// The test creates a session, checks the model selector (listing and the
/// selected model), streams a prompt through the fake model, cancels it,
/// closes the session, and finally verifies that prompting the closed session
/// fails with "Session not found".
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    // NOTE(review): presumably makes the test fail instead of blocking if
    // something waits on real I/O — confirm against the executor docs.
    cx.executor().forbid_parking();

    // Create a project for new_session
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let thread_store = cx.new(|cx| ThreadStore::new(cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        thread_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_session
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_session(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models: the list must be grouped, and the first entry of the
    // "Fake" group must have the id "fake/fake".
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected model info back to the model object held by the agent
    // so the test can drive its completion stream.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream a single text chunk back from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel: the in-flight prompt must still resolve to `Ok`
    // (`expect` would panic on `Err`).
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Explicitly close the session and drop the ACP thread.
    cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
        .await
        .unwrap();
    drop(acp_thread);
    // Prompting the closed session must now be rejected.
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
3220
3221#[gpui::test]
3222async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3223 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3224 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3225 let fake_model = model.as_fake();
3226
3227 let mut events = thread
3228 .update(cx, |thread, cx| {
3229 thread.send(UserMessageId::new(), ["Echo something"], cx)
3230 })
3231 .unwrap();
3232 cx.run_until_parked();
3233
3234 // Simulate streaming partial input.
3235 let input = json!({});
3236 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3237 LanguageModelToolUse {
3238 id: "1".into(),
3239 name: EchoTool::NAME.into(),
3240 raw_input: input.to_string(),
3241 input,
3242 is_input_complete: false,
3243 thought_signature: None,
3244 },
3245 ));
3246
3247 // Input streaming completed
3248 let input = json!({ "text": "Hello!" });
3249 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3250 LanguageModelToolUse {
3251 id: "1".into(),
3252 name: "echo".into(),
3253 raw_input: input.to_string(),
3254 input,
3255 is_input_complete: true,
3256 thought_signature: None,
3257 },
3258 ));
3259 fake_model.end_last_completion_stream();
3260 cx.run_until_parked();
3261
3262 let tool_call = expect_tool_call(&mut events).await;
3263 assert_eq!(
3264 tool_call,
3265 acp::ToolCall::new("1", "Echo")
3266 .raw_input(json!({}))
3267 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3268 );
3269 let update = expect_tool_call_update_fields(&mut events).await;
3270 assert_eq!(
3271 update,
3272 acp::ToolCallUpdate::new(
3273 "1",
3274 acp::ToolCallUpdateFields::new()
3275 .title("Echo")
3276 .kind(acp::ToolKind::Other)
3277 .raw_input(json!({ "text": "Hello!"}))
3278 )
3279 );
3280 let update = expect_tool_call_update_fields(&mut events).await;
3281 assert_eq!(
3282 update,
3283 acp::ToolCallUpdate::new(
3284 "1",
3285 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3286 )
3287 );
3288 let update = expect_tool_call_update_fields(&mut events).await;
3289 assert_eq!(
3290 update,
3291 acp::ToolCallUpdate::new(
3292 "1",
3293 acp::ToolCallUpdateFields::new()
3294 .status(acp::ToolCallStatus::Completed)
3295 .raw_output("Hello!")
3296 )
3297 );
3298}
3299
3300#[gpui::test]
3301async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3302 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3303 let fake_model = model.as_fake();
3304
3305 let mut events = thread
3306 .update(cx, |thread, cx| {
3307 thread.send(UserMessageId::new(), ["Hello!"], cx)
3308 })
3309 .unwrap();
3310 cx.run_until_parked();
3311
3312 fake_model.send_last_completion_stream_text_chunk("Hey!");
3313 fake_model.end_last_completion_stream();
3314
3315 let mut retry_events = Vec::new();
3316 while let Some(Ok(event)) = events.next().await {
3317 match event {
3318 ThreadEvent::Retry(retry_status) => {
3319 retry_events.push(retry_status);
3320 }
3321 ThreadEvent::Stop(..) => break,
3322 _ => {}
3323 }
3324 }
3325
3326 assert_eq!(retry_events.len(), 0);
3327 thread.read_with(cx, |thread, _cx| {
3328 assert_eq!(
3329 thread.to_markdown(),
3330 indoc! {"
3331 ## User
3332
3333 Hello!
3334
3335 ## Assistant
3336
3337 Hey!
3338 "}
3339 )
3340 });
3341}
3342
3343#[gpui::test]
3344async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3345 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3346 let fake_model = model.as_fake();
3347
3348 let mut events = thread
3349 .update(cx, |thread, cx| {
3350 thread.send(UserMessageId::new(), ["Hello!"], cx)
3351 })
3352 .unwrap();
3353 cx.run_until_parked();
3354
3355 fake_model.send_last_completion_stream_text_chunk("Hey,");
3356 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3357 provider: LanguageModelProviderName::new("Anthropic"),
3358 retry_after: Some(Duration::from_secs(3)),
3359 });
3360 fake_model.end_last_completion_stream();
3361
3362 cx.executor().advance_clock(Duration::from_secs(3));
3363 cx.run_until_parked();
3364
3365 fake_model.send_last_completion_stream_text_chunk("there!");
3366 fake_model.end_last_completion_stream();
3367 cx.run_until_parked();
3368
3369 let mut retry_events = Vec::new();
3370 while let Some(Ok(event)) = events.next().await {
3371 match event {
3372 ThreadEvent::Retry(retry_status) => {
3373 retry_events.push(retry_status);
3374 }
3375 ThreadEvent::Stop(..) => break,
3376 _ => {}
3377 }
3378 }
3379
3380 assert_eq!(retry_events.len(), 1);
3381 assert!(matches!(
3382 retry_events[0],
3383 acp_thread::RetryStatus { attempt: 1, .. }
3384 ));
3385 thread.read_with(cx, |thread, _cx| {
3386 assert_eq!(
3387 thread.to_markdown(),
3388 indoc! {"
3389 ## User
3390
3391 Hello!
3392
3393 ## Assistant
3394
3395 Hey,
3396
3397 [resume]
3398
3399 ## Assistant
3400
3401 there!
3402 "}
3403 )
3404 });
3405}
3406
/// When a completion emits a tool use and then fails with a retryable error,
/// the retried request must already contain the tool use together with its
/// result, and the thread must end with the text of the retried completion.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt: a complete tool use followed by an overload error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the `retry_after` delay and inspect the retried request.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // The first message is skipped — presumably the system prompt; TODO confirm.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Second attempt succeeds with plain text.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Drain the event stream to its end before inspecting the thread.
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3486
3487#[gpui::test]
3488async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3489 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3490 let fake_model = model.as_fake();
3491
3492 let mut events = thread
3493 .update(cx, |thread, cx| {
3494 thread.send(UserMessageId::new(), ["Hello!"], cx)
3495 })
3496 .unwrap();
3497 cx.run_until_parked();
3498
3499 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3500 fake_model.send_last_completion_stream_error(
3501 LanguageModelCompletionError::ServerOverloaded {
3502 provider: LanguageModelProviderName::new("Anthropic"),
3503 retry_after: Some(Duration::from_secs(3)),
3504 },
3505 );
3506 fake_model.end_last_completion_stream();
3507 cx.executor().advance_clock(Duration::from_secs(3));
3508 cx.run_until_parked();
3509 }
3510
3511 let mut errors = Vec::new();
3512 let mut retry_events = Vec::new();
3513 while let Some(event) = events.next().await {
3514 match event {
3515 Ok(ThreadEvent::Retry(retry_status)) => {
3516 retry_events.push(retry_status);
3517 }
3518 Ok(ThreadEvent::Stop(..)) => break,
3519 Err(error) => errors.push(error),
3520 _ => {}
3521 }
3522 }
3523
3524 assert_eq!(
3525 retry_events.len(),
3526 crate::thread::MAX_RETRY_ATTEMPTS as usize
3527 );
3528 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3529 assert_eq!(retry_events[i].attempt, i + 1);
3530 }
3531 assert_eq!(errors.len(), 1);
3532 let error = errors[0]
3533 .downcast_ref::<LanguageModelCompletionError>()
3534 .unwrap();
3535 assert!(matches!(
3536 error,
3537 LanguageModelCompletionError::ServerOverloaded { .. }
3538 ));
3539}
3540
3541/// Filters out the stop events for asserting against in tests
3542fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3543 result_events
3544 .into_iter()
3545 .filter_map(|event| match event.unwrap() {
3546 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3547 _ => None,
3548 })
3549 .collect()
3550}
3551
/// Bundle of everything `setup` creates for a thread test.
struct ThreadTest {
    /// The language model handed to the thread (a `FakeLanguageModel` when
    /// `setup` is called with `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test, created with `Thread::new`.
    thread: Entity<Thread>,
    /// The default-initialized project context passed to the thread.
    project_context: Entity<ProjectContext>,
    /// The context server store of the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake file system that backs the test project and the settings file.
    fs: Arc<FakeFs>,
}
3559
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// A model looked up in the `LanguageModelRegistry` by the id
    /// "claude-sonnet-4-latest"; `setup` authenticates its provider first.
    Sonnet4,
    /// A `FakeLanguageModel::default()` whose completion stream is driven by the test.
    Fake,
}
3564
3565impl TestModel {
3566 fn id(&self) -> LanguageModelId {
3567 match self {
3568 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3569 TestModel::Fake => unreachable!(),
3570 }
3571 }
3572}
3573
/// Builds the shared fixture for thread tests.
///
/// Writes a settings file with a "test-profile" agent profile that enables the
/// test tools, initializes settings (and, for `TestModel::Sonnet4`, the
/// language model stack), creates a project rooted at `/test`, resolves the
/// requested model, and constructs a `Thread` using it.
///
/// Returns the model, the thread, and the supporting entities as a `ThreadTest`.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably required for the `Sonnet4` path, which uses a
    // real HTTP client — confirm.
    cx.executor().allow_parking();

    // Write the user settings file into the fake file system before anything
    // reads it.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider setup.
            TestModel::Fake => {}
            // The Sonnet path sets up an HTTP client, the production client,
            // and the language model providers.
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Apply the settings file written above (and any later change to it).
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: either a fresh fake, or the registry entry matching
    // `model.id()` once its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3676
/// Initializes `env_logger` for the test binary, but only when the `RUST_LOG`
/// environment variable is set to a value `std::env::var` can read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if !rust_log_is_set {
        return;
    }
    env_logger::init();
}
3684
3685fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3686 let fs = fs.clone();
3687 cx.spawn({
3688 async move |cx| {
3689 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3690 cx.background_executor(),
3691 fs,
3692 paths::settings_file().clone(),
3693 );
3694 let _watcher_task = watcher_task;
3695
3696 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3697 cx.update(|cx| {
3698 SettingsStore::update_global(cx, |settings, cx| {
3699 settings.set_user_settings(&new_settings_content, cx)
3700 })
3701 })
3702 .ok();
3703 }
3704 }
3705 })
3706 .detach();
3707}
3708
3709fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3710 completion
3711 .tools
3712 .iter()
3713 .map(|tool| tool.name.clone())
3714 .collect()
3715}
3716
/// Registers and starts a fake context (MCP) server named `name` that
/// advertises `tools`.
///
/// The server is added to the global `ProjectSettings` as an enabled stdio
/// server, then started on `context_server_store` with a fake transport that
/// answers `Initialize`, `ListTools`, and `CallTool` requests.
///
/// Returns a receiver that yields one `(params, response_sender)` pair per
/// `CallTool` request; the test must send a response on the sender for the
/// call to complete.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add the server to the project settings; the command is a placeholder
    // because requests are served by the fake transport below.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report the latest protocol version and tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Tool listing: always the full `tools` list in a single page.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Tool calls: forward the params to the test and wait for its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server start-up work scheduled above run before returning.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3796
/// Checks `Thread::tokens_before_message` across three consecutive messages:
/// the first message reports `None`, and each later message reports the
/// `input_tokens` of the usage update received for the previous request.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3903
/// Checks that `Thread::tokens_before_message` returns `None` for a message
/// that was removed by `Thread::truncate`, while the remaining first message
/// keeps reporting `None`.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3974
3975#[gpui::test]
3976async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3977 init_test(cx);
3978
3979 let fs = FakeFs::new(cx.executor());
3980 fs.insert_tree("/root", json!({})).await;
3981 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3982
3983 // Test 1: Deny rule blocks command
3984 {
3985 let environment = Rc::new(cx.update(|cx| {
3986 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3987 }));
3988
3989 cx.update(|cx| {
3990 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3991 settings.tool_permissions.tools.insert(
3992 TerminalTool::NAME.into(),
3993 agent_settings::ToolRules {
3994 default: Some(settings::ToolPermissionMode::Confirm),
3995 always_allow: vec![],
3996 always_deny: vec![
3997 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3998 ],
3999 always_confirm: vec![],
4000 invalid_patterns: vec![],
4001 },
4002 );
4003 agent_settings::AgentSettings::override_global(settings, cx);
4004 });
4005
4006 #[allow(clippy::arc_with_non_send_sync)]
4007 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4008 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4009
4010 let task = cx.update(|cx| {
4011 tool.run(
4012 crate::TerminalToolInput {
4013 command: "rm -rf /".to_string(),
4014 cd: ".".to_string(),
4015 timeout_ms: None,
4016 },
4017 event_stream,
4018 cx,
4019 )
4020 });
4021
4022 let result = task.await;
4023 assert!(
4024 result.is_err(),
4025 "expected command to be blocked by deny rule"
4026 );
4027 let err_msg = result.unwrap_err().to_string().to_lowercase();
4028 assert!(
4029 err_msg.contains("blocked"),
4030 "error should mention the command was blocked"
4031 );
4032 }
4033
4034 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4035 {
4036 let environment = Rc::new(cx.update(|cx| {
4037 FakeThreadEnvironment::default()
4038 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4039 }));
4040
4041 cx.update(|cx| {
4042 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4043 settings.tool_permissions.tools.insert(
4044 TerminalTool::NAME.into(),
4045 agent_settings::ToolRules {
4046 default: Some(settings::ToolPermissionMode::Deny),
4047 always_allow: vec![
4048 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4049 ],
4050 always_deny: vec![],
4051 always_confirm: vec![],
4052 invalid_patterns: vec![],
4053 },
4054 );
4055 agent_settings::AgentSettings::override_global(settings, cx);
4056 });
4057
4058 #[allow(clippy::arc_with_non_send_sync)]
4059 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4060 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4061
4062 let task = cx.update(|cx| {
4063 tool.run(
4064 crate::TerminalToolInput {
4065 command: "echo hello".to_string(),
4066 cd: ".".to_string(),
4067 timeout_ms: None,
4068 },
4069 event_stream,
4070 cx,
4071 )
4072 });
4073
4074 let update = rx.expect_update_fields().await;
4075 assert!(
4076 update.content.iter().any(|blocks| {
4077 blocks
4078 .iter()
4079 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4080 }),
4081 "expected terminal content (allow rule should skip confirmation and override default deny)"
4082 );
4083
4084 let result = task.await;
4085 assert!(
4086 result.is_ok(),
4087 "expected command to succeed without confirmation"
4088 );
4089 }
4090
4091 // Test 3: global default: allow does NOT override always_confirm patterns
4092 {
4093 let environment = Rc::new(cx.update(|cx| {
4094 FakeThreadEnvironment::default()
4095 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4096 }));
4097
4098 cx.update(|cx| {
4099 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4100 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4101 settings.tool_permissions.tools.insert(
4102 TerminalTool::NAME.into(),
4103 agent_settings::ToolRules {
4104 default: Some(settings::ToolPermissionMode::Allow),
4105 always_allow: vec![],
4106 always_deny: vec![],
4107 always_confirm: vec![
4108 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4109 ],
4110 invalid_patterns: vec![],
4111 },
4112 );
4113 agent_settings::AgentSettings::override_global(settings, cx);
4114 });
4115
4116 #[allow(clippy::arc_with_non_send_sync)]
4117 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4118 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4119
4120 let _task = cx.update(|cx| {
4121 tool.run(
4122 crate::TerminalToolInput {
4123 command: "sudo rm file".to_string(),
4124 cd: ".".to_string(),
4125 timeout_ms: None,
4126 },
4127 event_stream,
4128 cx,
4129 )
4130 });
4131
4132 // With global default: allow, confirm patterns are still respected
4133 // The expect_authorization() call will panic if no authorization is requested,
4134 // which validates that the confirm pattern still triggers confirmation
4135 let _auth = rx.expect_authorization().await;
4136
4137 drop(_task);
4138 }
4139
4140 // Test 4: tool-specific default: deny is respected even with global default: allow
4141 {
4142 let environment = Rc::new(cx.update(|cx| {
4143 FakeThreadEnvironment::default()
4144 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4145 }));
4146
4147 cx.update(|cx| {
4148 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4149 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4150 settings.tool_permissions.tools.insert(
4151 TerminalTool::NAME.into(),
4152 agent_settings::ToolRules {
4153 default: Some(settings::ToolPermissionMode::Deny),
4154 always_allow: vec![],
4155 always_deny: vec![],
4156 always_confirm: vec![],
4157 invalid_patterns: vec![],
4158 },
4159 );
4160 agent_settings::AgentSettings::override_global(settings, cx);
4161 });
4162
4163 #[allow(clippy::arc_with_non_send_sync)]
4164 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4165 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4166
4167 let task = cx.update(|cx| {
4168 tool.run(
4169 crate::TerminalToolInput {
4170 command: "echo hello".to_string(),
4171 cd: ".".to_string(),
4172 timeout_ms: None,
4173 },
4174 event_stream,
4175 cx,
4176 )
4177 });
4178
4179 // tool-specific default: deny is respected even with global default: allow
4180 let result = task.await;
4181 assert!(
4182 result.is_err(),
4183 "expected command to be blocked by tool-specific deny default"
4184 );
4185 let err_msg = result.unwrap_err().to_string().to_lowercase();
4186 assert!(
4187 err_msg.contains("disabled"),
4188 "error should mention the tool is disabled, got: {err_msg}"
4189 );
4190 }
4191}
4192
4193#[gpui::test]
4194async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4195 init_test(cx);
4196
4197 cx.update(|cx| {
4198 cx.update_flags(true, vec!["subagents".to_string()]);
4199 });
4200
4201 let fs = FakeFs::new(cx.executor());
4202 fs.insert_tree(path!("/test"), json!({})).await;
4203 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4204 let project_context = cx.new(|_cx| ProjectContext::default());
4205 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4206 let context_server_registry =
4207 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4208 let model = Arc::new(FakeLanguageModel::default());
4209
4210 let environment = Rc::new(cx.update(|cx| {
4211 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4212 }));
4213
4214 let thread = cx.new(|cx| {
4215 let mut thread = Thread::new(
4216 project.clone(),
4217 project_context,
4218 context_server_registry,
4219 Templates::new(),
4220 Some(model),
4221 cx,
4222 );
4223 thread.add_default_tools(None, environment, cx);
4224 thread
4225 });
4226
4227 thread.read_with(cx, |thread, _| {
4228 assert!(
4229 thread.has_registered_tool(SubagentTool::NAME),
4230 "subagent tool should be present when feature flag is enabled"
4231 );
4232 });
4233}
4234
4235#[gpui::test]
4236async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4237 init_test(cx);
4238
4239 cx.update(|cx| {
4240 cx.update_flags(true, vec!["subagents".to_string()]);
4241 });
4242
4243 let fs = FakeFs::new(cx.executor());
4244 fs.insert_tree(path!("/test"), json!({})).await;
4245 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4246 let project_context = cx.new(|_cx| ProjectContext::default());
4247 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4248 let context_server_registry =
4249 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4250 let model = Arc::new(FakeLanguageModel::default());
4251
4252 let parent_thread = cx.new(|cx| {
4253 Thread::new(
4254 project.clone(),
4255 project_context,
4256 context_server_registry,
4257 Templates::new(),
4258 Some(model.clone()),
4259 cx,
4260 )
4261 });
4262
4263 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4264 subagent_thread.read_with(cx, |subagent_thread, cx| {
4265 assert!(subagent_thread.is_subagent());
4266 assert_eq!(subagent_thread.depth(), 1);
4267 assert_eq!(
4268 subagent_thread.model().map(|model| model.id()),
4269 Some(model.id())
4270 );
4271 assert_eq!(
4272 subagent_thread.parent_thread_id(),
4273 Some(parent_thread.read(cx).id().clone())
4274 );
4275 });
4276}
4277
4278#[gpui::test]
4279async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4280 init_test(cx);
4281
4282 cx.update(|cx| {
4283 cx.update_flags(true, vec!["subagents".to_string()]);
4284 });
4285
4286 let fs = FakeFs::new(cx.executor());
4287 fs.insert_tree(path!("/test"), json!({})).await;
4288 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4289 let project_context = cx.new(|_cx| ProjectContext::default());
4290 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4291 let context_server_registry =
4292 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4293 let model = Arc::new(FakeLanguageModel::default());
4294 let environment = Rc::new(cx.update(|cx| {
4295 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4296 }));
4297
4298 let deep_parent_thread = cx.new(|cx| {
4299 let mut thread = Thread::new(
4300 project.clone(),
4301 project_context,
4302 context_server_registry,
4303 Templates::new(),
4304 Some(model.clone()),
4305 cx,
4306 );
4307 thread.set_subagent_context(SubagentContext {
4308 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4309 depth: MAX_SUBAGENT_DEPTH - 1,
4310 });
4311 thread
4312 });
4313 let deep_subagent_thread = cx.new(|cx| {
4314 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4315 thread.add_default_tools(None, environment, cx);
4316 thread
4317 });
4318
4319 deep_subagent_thread.read_with(cx, |thread, _| {
4320 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4321 assert!(
4322 !thread.has_registered_tool(SubagentTool::NAME),
4323 "subagent tool should not be present at max depth"
4324 );
4325 });
4326}
4327
4328#[gpui::test]
4329async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4330 init_test(cx);
4331
4332 cx.update(|cx| {
4333 cx.update_flags(true, vec!["subagents".to_string()]);
4334 });
4335
4336 let fs = FakeFs::new(cx.executor());
4337 fs.insert_tree(path!("/test"), json!({})).await;
4338 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4339 let project_context = cx.new(|_cx| ProjectContext::default());
4340 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4341 let context_server_registry =
4342 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4343 let model = Arc::new(FakeLanguageModel::default());
4344
4345 let parent = cx.new(|cx| {
4346 Thread::new(
4347 project.clone(),
4348 project_context.clone(),
4349 context_server_registry.clone(),
4350 Templates::new(),
4351 Some(model.clone()),
4352 cx,
4353 )
4354 });
4355
4356 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4357
4358 parent.update(cx, |thread, _cx| {
4359 thread.register_running_subagent(subagent.downgrade());
4360 });
4361
4362 subagent
4363 .update(cx, |thread, cx| {
4364 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4365 })
4366 .unwrap();
4367 cx.run_until_parked();
4368
4369 subagent.read_with(cx, |thread, _| {
4370 assert!(!thread.is_turn_complete(), "subagent should be running");
4371 });
4372
4373 parent.update(cx, |thread, cx| {
4374 thread.cancel(cx).detach();
4375 });
4376
4377 subagent.read_with(cx, |thread, _| {
4378 assert!(
4379 thread.is_turn_complete(),
4380 "subagent should be cancelled when parent cancels"
4381 );
4382 });
4383}
4384
4385#[gpui::test]
4386async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4387 // This test verifies that the subagent tool properly handles user cancellation
4388 // via `event_stream.cancelled_by_user()` and stops all running subagents.
4389 init_test(cx);
4390 always_allow_tools(cx);
4391
4392 cx.update(|cx| {
4393 cx.update_flags(true, vec!["subagents".to_string()]);
4394 });
4395
4396 let fs = FakeFs::new(cx.executor());
4397 fs.insert_tree(path!("/test"), json!({})).await;
4398 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4399 let project_context = cx.new(|_cx| ProjectContext::default());
4400 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4401 let context_server_registry =
4402 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4403 let model = Arc::new(FakeLanguageModel::default());
4404 let environment = Rc::new(cx.update(|cx| {
4405 FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
4406 }));
4407
4408 let parent = cx.new(|cx| {
4409 Thread::new(
4410 project.clone(),
4411 project_context.clone(),
4412 context_server_registry.clone(),
4413 Templates::new(),
4414 Some(model.clone()),
4415 cx,
4416 )
4417 });
4418
4419 #[allow(clippy::arc_with_non_send_sync)]
4420 let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));
4421
4422 let (event_stream, _rx, mut cancellation_tx) =
4423 crate::ToolCallEventStream::test_with_cancellation();
4424
4425 // Start the subagent tool
4426 let task = cx.update(|cx| {
4427 tool.run(
4428 SubagentToolInput {
4429 label: "Long running task".to_string(),
4430 task_prompt: "Do a very long task that takes forever".to_string(),
4431 summary_prompt: "Summarize".to_string(),
4432 timeout_ms: None,
4433 allowed_tools: None,
4434 },
4435 event_stream.clone(),
4436 cx,
4437 )
4438 });
4439
4440 cx.run_until_parked();
4441
4442 // Signal cancellation via the event stream
4443 crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4444
4445 // The task should complete promptly with a cancellation error
4446 let timeout = cx.background_executor.timer(Duration::from_secs(5));
4447 let result = futures::select! {
4448 result = task.fuse() => result,
4449 _ = timeout.fuse() => {
4450 panic!("subagent tool did not respond to cancellation within timeout");
4451 }
4452 };
4453
4454 // Verify we got a cancellation error
4455 let err = result.unwrap_err();
4456 assert!(
4457 err.to_string().contains("cancelled by user"),
4458 "expected cancellation error, got: {}",
4459 err
4460 );
4461}
4462
4463#[gpui::test]
4464async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4465 init_test(cx);
4466 always_allow_tools(cx);
4467
4468 cx.update(|cx| {
4469 cx.update_flags(true, vec!["subagents".to_string()]);
4470 });
4471
4472 let fs = FakeFs::new(cx.executor());
4473 fs.insert_tree(path!("/test"), json!({})).await;
4474 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4475 let project_context = cx.new(|_cx| ProjectContext::default());
4476 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4477 let context_server_registry =
4478 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4479 cx.update(LanguageModelRegistry::test);
4480 let model = Arc::new(FakeLanguageModel::default());
4481 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4482 let native_agent = NativeAgent::new(
4483 project.clone(),
4484 thread_store,
4485 Templates::new(),
4486 None,
4487 fs,
4488 &mut cx.to_async(),
4489 )
4490 .await
4491 .unwrap();
4492 let parent_thread = cx.new(|cx| {
4493 Thread::new(
4494 project.clone(),
4495 project_context,
4496 context_server_registry,
4497 Templates::new(),
4498 Some(model.clone()),
4499 cx,
4500 )
4501 });
4502
4503 let mut handles = Vec::new();
4504 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4505 let handle = cx
4506 .update(|cx| {
4507 NativeThreadEnvironment::create_subagent_thread(
4508 native_agent.downgrade(),
4509 parent_thread.clone(),
4510 "some title".to_string(),
4511 "some task".to_string(),
4512 None,
4513 None,
4514 cx,
4515 )
4516 })
4517 .expect("Expected to be able to create subagent thread");
4518 handles.push(handle);
4519 }
4520
4521 let result = cx.update(|cx| {
4522 NativeThreadEnvironment::create_subagent_thread(
4523 native_agent.downgrade(),
4524 parent_thread.clone(),
4525 "some title".to_string(),
4526 "some task".to_string(),
4527 None,
4528 None,
4529 cx,
4530 )
4531 });
4532 assert!(result.is_err());
4533 assert_eq!(
4534 result.err().unwrap().to_string(),
4535 format!(
4536 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4537 MAX_PARALLEL_SUBAGENTS
4538 )
4539 );
4540}
4541
4542#[gpui::test]
4543async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4544 init_test(cx);
4545
4546 always_allow_tools(cx);
4547
4548 cx.update(|cx| {
4549 cx.update_flags(true, vec!["subagents".to_string()]);
4550 });
4551
4552 let fs = FakeFs::new(cx.executor());
4553 fs.insert_tree(path!("/test"), json!({})).await;
4554 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4555 let project_context = cx.new(|_cx| ProjectContext::default());
4556 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4557 let context_server_registry =
4558 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4559 cx.update(LanguageModelRegistry::test);
4560 let model = Arc::new(FakeLanguageModel::default());
4561 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4562 let native_agent = NativeAgent::new(
4563 project.clone(),
4564 thread_store,
4565 Templates::new(),
4566 None,
4567 fs,
4568 &mut cx.to_async(),
4569 )
4570 .await
4571 .unwrap();
4572 let parent_thread = cx.new(|cx| {
4573 Thread::new(
4574 project.clone(),
4575 project_context,
4576 context_server_registry,
4577 Templates::new(),
4578 Some(model.clone()),
4579 cx,
4580 )
4581 });
4582
4583 let subagent_handle = cx
4584 .update(|cx| {
4585 NativeThreadEnvironment::create_subagent_thread(
4586 native_agent.downgrade(),
4587 parent_thread.clone(),
4588 "some title".to_string(),
4589 "task prompt".to_string(),
4590 Some(Duration::from_millis(10)),
4591 None,
4592 cx,
4593 )
4594 })
4595 .expect("Failed to create subagent");
4596
4597 let summary_task =
4598 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4599
4600 cx.run_until_parked();
4601
4602 {
4603 let messages = model.pending_completions().last().unwrap().messages.clone();
4604 // Ensure that model received a system prompt
4605 assert_eq!(messages[0].role, Role::System);
4606 // Ensure that model received a task prompt
4607 assert_eq!(messages[1].role, Role::User);
4608 assert_eq!(
4609 messages[1].content,
4610 vec![MessageContent::Text("task prompt".to_string())]
4611 );
4612 }
4613
4614 model.send_last_completion_stream_text_chunk("Some task response...");
4615 model.end_last_completion_stream();
4616
4617 cx.run_until_parked();
4618
4619 {
4620 let messages = model.pending_completions().last().unwrap().messages.clone();
4621 assert_eq!(messages[2].role, Role::Assistant);
4622 assert_eq!(
4623 messages[2].content,
4624 vec![MessageContent::Text("Some task response...".to_string())]
4625 );
4626 // Ensure that model received a summary prompt
4627 assert_eq!(messages[3].role, Role::User);
4628 assert_eq!(
4629 messages[3].content,
4630 vec![MessageContent::Text("summary prompt".to_string())]
4631 );
4632 }
4633
4634 model.send_last_completion_stream_text_chunk("Some summary...");
4635 model.end_last_completion_stream();
4636
4637 let result = summary_task.await;
4638 assert_eq!(result.unwrap(), "Some summary...\n");
4639}
4640
4641#[gpui::test]
4642async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4643 cx: &mut TestAppContext,
4644) {
4645 init_test(cx);
4646
4647 always_allow_tools(cx);
4648
4649 cx.update(|cx| {
4650 cx.update_flags(true, vec!["subagents".to_string()]);
4651 });
4652
4653 let fs = FakeFs::new(cx.executor());
4654 fs.insert_tree(path!("/test"), json!({})).await;
4655 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4656 let project_context = cx.new(|_cx| ProjectContext::default());
4657 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4658 let context_server_registry =
4659 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4660 cx.update(LanguageModelRegistry::test);
4661 let model = Arc::new(FakeLanguageModel::default());
4662 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4663 let native_agent = NativeAgent::new(
4664 project.clone(),
4665 thread_store,
4666 Templates::new(),
4667 None,
4668 fs,
4669 &mut cx.to_async(),
4670 )
4671 .await
4672 .unwrap();
4673 let parent_thread = cx.new(|cx| {
4674 Thread::new(
4675 project.clone(),
4676 project_context,
4677 context_server_registry,
4678 Templates::new(),
4679 Some(model.clone()),
4680 cx,
4681 )
4682 });
4683
4684 let subagent_handle = cx
4685 .update(|cx| {
4686 NativeThreadEnvironment::create_subagent_thread(
4687 native_agent.downgrade(),
4688 parent_thread.clone(),
4689 "some title".to_string(),
4690 "task prompt".to_string(),
4691 Some(Duration::from_millis(100)),
4692 None,
4693 cx,
4694 )
4695 })
4696 .expect("Failed to create subagent");
4697
4698 let summary_task =
4699 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4700
4701 cx.run_until_parked();
4702
4703 {
4704 let messages = model.pending_completions().last().unwrap().messages.clone();
4705 // Ensure that model received a system prompt
4706 assert_eq!(messages[0].role, Role::System);
4707 // Ensure that model received a task prompt
4708 assert_eq!(
4709 messages[1].content,
4710 vec![MessageContent::Text("task prompt".to_string())]
4711 );
4712 }
4713
4714 // Don't complete the initial model stream — let the timeout expire instead.
4715 cx.executor().advance_clock(Duration::from_millis(200));
4716 cx.run_until_parked();
4717
4718 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4719 // instead of the summary_prompt.
4720 {
4721 let messages = model.pending_completions().last().unwrap().messages.clone();
4722 let last_user_message = messages
4723 .iter()
4724 .rev()
4725 .find(|m| m.role == Role::User)
4726 .unwrap();
4727 assert_eq!(
4728 last_user_message.content,
4729 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
4730 );
4731 }
4732
4733 model.send_last_completion_stream_text_chunk("Some context low response...");
4734 model.end_last_completion_stream();
4735
4736 let result = summary_task.await;
4737 assert_eq!(result.unwrap(), "Some context low response...\n");
4738}
4739
4740#[gpui::test]
4741async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4742 init_test(cx);
4743
4744 always_allow_tools(cx);
4745
4746 cx.update(|cx| {
4747 cx.update_flags(true, vec!["subagents".to_string()]);
4748 });
4749
4750 let fs = FakeFs::new(cx.executor());
4751 fs.insert_tree(path!("/test"), json!({})).await;
4752 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4753 let project_context = cx.new(|_cx| ProjectContext::default());
4754 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4755 let context_server_registry =
4756 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4757 cx.update(LanguageModelRegistry::test);
4758 let model = Arc::new(FakeLanguageModel::default());
4759 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4760 let native_agent = NativeAgent::new(
4761 project.clone(),
4762 thread_store,
4763 Templates::new(),
4764 None,
4765 fs,
4766 &mut cx.to_async(),
4767 )
4768 .await
4769 .unwrap();
4770 let parent_thread = cx.new(|cx| {
4771 let mut thread = Thread::new(
4772 project.clone(),
4773 project_context,
4774 context_server_registry,
4775 Templates::new(),
4776 Some(model.clone()),
4777 cx,
4778 );
4779 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4780 thread.add_tool(GrepTool::new(project.clone()), None);
4781 thread
4782 });
4783
4784 let _subagent_handle = cx
4785 .update(|cx| {
4786 NativeThreadEnvironment::create_subagent_thread(
4787 native_agent.downgrade(),
4788 parent_thread.clone(),
4789 "some title".to_string(),
4790 "task prompt".to_string(),
4791 Some(Duration::from_millis(10)),
4792 None,
4793 cx,
4794 )
4795 })
4796 .expect("Failed to create subagent");
4797
4798 cx.run_until_parked();
4799
4800 let tools = model
4801 .pending_completions()
4802 .last()
4803 .unwrap()
4804 .tools
4805 .iter()
4806 .map(|tool| tool.name.clone())
4807 .collect::<Vec<_>>();
4808 assert_eq!(tools.len(), 2);
4809 assert!(tools.contains(&"grep".to_string()));
4810 assert!(tools.contains(&"list_directory".to_string()));
4811}
4812
4813#[gpui::test]
4814async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4815 init_test(cx);
4816
4817 always_allow_tools(cx);
4818
4819 cx.update(|cx| {
4820 cx.update_flags(true, vec!["subagents".to_string()]);
4821 });
4822
4823 let fs = FakeFs::new(cx.executor());
4824 fs.insert_tree(path!("/test"), json!({})).await;
4825 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4826 let project_context = cx.new(|_cx| ProjectContext::default());
4827 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4828 let context_server_registry =
4829 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4830 cx.update(LanguageModelRegistry::test);
4831 let model = Arc::new(FakeLanguageModel::default());
4832 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4833 let native_agent = NativeAgent::new(
4834 project.clone(),
4835 thread_store,
4836 Templates::new(),
4837 None,
4838 fs,
4839 &mut cx.to_async(),
4840 )
4841 .await
4842 .unwrap();
4843 let parent_thread = cx.new(|cx| {
4844 let mut thread = Thread::new(
4845 project.clone(),
4846 project_context,
4847 context_server_registry,
4848 Templates::new(),
4849 Some(model.clone()),
4850 cx,
4851 );
4852 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4853 thread.add_tool(GrepTool::new(project.clone()), None);
4854 thread
4855 });
4856
4857 let _subagent_handle = cx
4858 .update(|cx| {
4859 NativeThreadEnvironment::create_subagent_thread(
4860 native_agent.downgrade(),
4861 parent_thread.clone(),
4862 "some title".to_string(),
4863 "task prompt".to_string(),
4864 Some(Duration::from_millis(10)),
4865 Some(vec!["grep".to_string()]),
4866 cx,
4867 )
4868 })
4869 .expect("Failed to create subagent");
4870
4871 cx.run_until_parked();
4872
4873 let tools = model
4874 .pending_completions()
4875 .last()
4876 .unwrap()
4877 .tools
4878 .iter()
4879 .map(|tool| tool.name.clone())
4880 .collect::<Vec<_>>();
4881 assert_eq!(tools, vec!["grep"]);
4882}
4883
4884#[gpui::test]
4885async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4886 init_test(cx);
4887
4888 let fs = FakeFs::new(cx.executor());
4889 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4890 .await;
4891 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4892
4893 cx.update(|cx| {
4894 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4895 settings.tool_permissions.tools.insert(
4896 EditFileTool::NAME.into(),
4897 agent_settings::ToolRules {
4898 default: Some(settings::ToolPermissionMode::Allow),
4899 always_allow: vec![],
4900 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4901 always_confirm: vec![],
4902 invalid_patterns: vec![],
4903 },
4904 );
4905 agent_settings::AgentSettings::override_global(settings, cx);
4906 });
4907
4908 let context_server_registry =
4909 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4910 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4911 let templates = crate::Templates::new();
4912 let thread = cx.new(|cx| {
4913 crate::Thread::new(
4914 project.clone(),
4915 cx.new(|_cx| prompt_store::ProjectContext::default()),
4916 context_server_registry,
4917 templates.clone(),
4918 None,
4919 cx,
4920 )
4921 });
4922
4923 #[allow(clippy::arc_with_non_send_sync)]
4924 let tool = Arc::new(crate::EditFileTool::new(
4925 project.clone(),
4926 thread.downgrade(),
4927 language_registry,
4928 templates,
4929 ));
4930 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4931
4932 let task = cx.update(|cx| {
4933 tool.run(
4934 crate::EditFileToolInput {
4935 display_description: "Edit sensitive file".to_string(),
4936 path: "root/sensitive_config.txt".into(),
4937 mode: crate::EditFileMode::Edit,
4938 },
4939 event_stream,
4940 cx,
4941 )
4942 });
4943
4944 let result = task.await;
4945 assert!(result.is_err(), "expected edit to be blocked");
4946 assert!(
4947 result.unwrap_err().to_string().contains("blocked"),
4948 "error should mention the edit was blocked"
4949 );
4950}
4951
4952#[gpui::test]
4953async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4954 init_test(cx);
4955
4956 let fs = FakeFs::new(cx.executor());
4957 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4958 .await;
4959 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4960
4961 cx.update(|cx| {
4962 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4963 settings.tool_permissions.tools.insert(
4964 DeletePathTool::NAME.into(),
4965 agent_settings::ToolRules {
4966 default: Some(settings::ToolPermissionMode::Allow),
4967 always_allow: vec![],
4968 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4969 always_confirm: vec![],
4970 invalid_patterns: vec![],
4971 },
4972 );
4973 agent_settings::AgentSettings::override_global(settings, cx);
4974 });
4975
4976 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4977
4978 #[allow(clippy::arc_with_non_send_sync)]
4979 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4980 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4981
4982 let task = cx.update(|cx| {
4983 tool.run(
4984 crate::DeletePathToolInput {
4985 path: "root/important_data.txt".to_string(),
4986 },
4987 event_stream,
4988 cx,
4989 )
4990 });
4991
4992 let result = task.await;
4993 assert!(result.is_err(), "expected deletion to be blocked");
4994 assert!(
4995 result.unwrap_err().to_string().contains("blocked"),
4996 "error should mention the deletion was blocked"
4997 );
4998}
4999
5000#[gpui::test]
5001async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5002 init_test(cx);
5003
5004 let fs = FakeFs::new(cx.executor());
5005 fs.insert_tree(
5006 "/root",
5007 json!({
5008 "safe.txt": "content",
5009 "protected": {}
5010 }),
5011 )
5012 .await;
5013 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5014
5015 cx.update(|cx| {
5016 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5017 settings.tool_permissions.tools.insert(
5018 MovePathTool::NAME.into(),
5019 agent_settings::ToolRules {
5020 default: Some(settings::ToolPermissionMode::Allow),
5021 always_allow: vec![],
5022 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5023 always_confirm: vec![],
5024 invalid_patterns: vec![],
5025 },
5026 );
5027 agent_settings::AgentSettings::override_global(settings, cx);
5028 });
5029
5030 #[allow(clippy::arc_with_non_send_sync)]
5031 let tool = Arc::new(crate::MovePathTool::new(project));
5032 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5033
5034 let task = cx.update(|cx| {
5035 tool.run(
5036 crate::MovePathToolInput {
5037 source_path: "root/safe.txt".to_string(),
5038 destination_path: "root/protected/safe.txt".to_string(),
5039 },
5040 event_stream,
5041 cx,
5042 )
5043 });
5044
5045 let result = task.await;
5046 assert!(
5047 result.is_err(),
5048 "expected move to be blocked due to destination path"
5049 );
5050 assert!(
5051 result.unwrap_err().to_string().contains("blocked"),
5052 "error should mention the move was blocked"
5053 );
5054}
5055
5056#[gpui::test]
5057async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5058 init_test(cx);
5059
5060 let fs = FakeFs::new(cx.executor());
5061 fs.insert_tree(
5062 "/root",
5063 json!({
5064 "secret.txt": "secret content",
5065 "public": {}
5066 }),
5067 )
5068 .await;
5069 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5070
5071 cx.update(|cx| {
5072 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5073 settings.tool_permissions.tools.insert(
5074 MovePathTool::NAME.into(),
5075 agent_settings::ToolRules {
5076 default: Some(settings::ToolPermissionMode::Allow),
5077 always_allow: vec![],
5078 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5079 always_confirm: vec![],
5080 invalid_patterns: vec![],
5081 },
5082 );
5083 agent_settings::AgentSettings::override_global(settings, cx);
5084 });
5085
5086 #[allow(clippy::arc_with_non_send_sync)]
5087 let tool = Arc::new(crate::MovePathTool::new(project));
5088 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5089
5090 let task = cx.update(|cx| {
5091 tool.run(
5092 crate::MovePathToolInput {
5093 source_path: "root/secret.txt".to_string(),
5094 destination_path: "root/public/not_secret.txt".to_string(),
5095 },
5096 event_stream,
5097 cx,
5098 )
5099 });
5100
5101 let result = task.await;
5102 assert!(
5103 result.is_err(),
5104 "expected move to be blocked due to source path"
5105 );
5106 assert!(
5107 result.unwrap_err().to_string().contains("blocked"),
5108 "error should mention the move was blocked"
5109 );
5110}
5111
5112#[gpui::test]
5113async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5114 init_test(cx);
5115
5116 let fs = FakeFs::new(cx.executor());
5117 fs.insert_tree(
5118 "/root",
5119 json!({
5120 "confidential.txt": "confidential data",
5121 "dest": {}
5122 }),
5123 )
5124 .await;
5125 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5126
5127 cx.update(|cx| {
5128 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5129 settings.tool_permissions.tools.insert(
5130 CopyPathTool::NAME.into(),
5131 agent_settings::ToolRules {
5132 default: Some(settings::ToolPermissionMode::Allow),
5133 always_allow: vec![],
5134 always_deny: vec![
5135 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5136 ],
5137 always_confirm: vec![],
5138 invalid_patterns: vec![],
5139 },
5140 );
5141 agent_settings::AgentSettings::override_global(settings, cx);
5142 });
5143
5144 #[allow(clippy::arc_with_non_send_sync)]
5145 let tool = Arc::new(crate::CopyPathTool::new(project));
5146 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5147
5148 let task = cx.update(|cx| {
5149 tool.run(
5150 crate::CopyPathToolInput {
5151 source_path: "root/confidential.txt".to_string(),
5152 destination_path: "root/dest/copy.txt".to_string(),
5153 },
5154 event_stream,
5155 cx,
5156 )
5157 });
5158
5159 let result = task.await;
5160 assert!(result.is_err(), "expected copy to be blocked");
5161 assert!(
5162 result.unwrap_err().to_string().contains("blocked"),
5163 "error should mention the copy was blocked"
5164 );
5165}
5166
5167#[gpui::test]
5168async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5169 init_test(cx);
5170
5171 let fs = FakeFs::new(cx.executor());
5172 fs.insert_tree(
5173 "/root",
5174 json!({
5175 "normal.txt": "normal content",
5176 "readonly": {
5177 "config.txt": "readonly content"
5178 }
5179 }),
5180 )
5181 .await;
5182 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5183
5184 cx.update(|cx| {
5185 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5186 settings.tool_permissions.tools.insert(
5187 SaveFileTool::NAME.into(),
5188 agent_settings::ToolRules {
5189 default: Some(settings::ToolPermissionMode::Allow),
5190 always_allow: vec![],
5191 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5192 always_confirm: vec![],
5193 invalid_patterns: vec![],
5194 },
5195 );
5196 agent_settings::AgentSettings::override_global(settings, cx);
5197 });
5198
5199 #[allow(clippy::arc_with_non_send_sync)]
5200 let tool = Arc::new(crate::SaveFileTool::new(project));
5201 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5202
5203 let task = cx.update(|cx| {
5204 tool.run(
5205 crate::SaveFileToolInput {
5206 paths: vec![
5207 std::path::PathBuf::from("root/normal.txt"),
5208 std::path::PathBuf::from("root/readonly/config.txt"),
5209 ],
5210 },
5211 event_stream,
5212 cx,
5213 )
5214 });
5215
5216 let result = task.await;
5217 assert!(
5218 result.is_err(),
5219 "expected save to be blocked due to denied path"
5220 );
5221 assert!(
5222 result.unwrap_err().to_string().contains("blocked"),
5223 "error should mention the save was blocked"
5224 );
5225}
5226
5227#[gpui::test]
5228async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5229 init_test(cx);
5230
5231 let fs = FakeFs::new(cx.executor());
5232 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5233 .await;
5234 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5235
5236 cx.update(|cx| {
5237 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5238 settings.tool_permissions.tools.insert(
5239 SaveFileTool::NAME.into(),
5240 agent_settings::ToolRules {
5241 default: Some(settings::ToolPermissionMode::Allow),
5242 always_allow: vec![],
5243 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5244 always_confirm: vec![],
5245 invalid_patterns: vec![],
5246 },
5247 );
5248 agent_settings::AgentSettings::override_global(settings, cx);
5249 });
5250
5251 #[allow(clippy::arc_with_non_send_sync)]
5252 let tool = Arc::new(crate::SaveFileTool::new(project));
5253 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5254
5255 let task = cx.update(|cx| {
5256 tool.run(
5257 crate::SaveFileToolInput {
5258 paths: vec![std::path::PathBuf::from("root/config.secret")],
5259 },
5260 event_stream,
5261 cx,
5262 )
5263 });
5264
5265 let result = task.await;
5266 assert!(result.is_err(), "expected save to be blocked");
5267 assert!(
5268 result.unwrap_err().to_string().contains("blocked"),
5269 "error should mention the save was blocked"
5270 );
5271}
5272
5273#[gpui::test]
5274async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5275 init_test(cx);
5276
5277 cx.update(|cx| {
5278 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5279 settings.tool_permissions.tools.insert(
5280 WebSearchTool::NAME.into(),
5281 agent_settings::ToolRules {
5282 default: Some(settings::ToolPermissionMode::Allow),
5283 always_allow: vec![],
5284 always_deny: vec![
5285 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5286 ],
5287 always_confirm: vec![],
5288 invalid_patterns: vec![],
5289 },
5290 );
5291 agent_settings::AgentSettings::override_global(settings, cx);
5292 });
5293
5294 #[allow(clippy::arc_with_non_send_sync)]
5295 let tool = Arc::new(crate::WebSearchTool);
5296 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5297
5298 let input: crate::WebSearchToolInput =
5299 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5300
5301 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5302
5303 let result = task.await;
5304 assert!(result.is_err(), "expected search to be blocked");
5305 assert!(
5306 result.unwrap_err().to_string().contains("blocked"),
5307 "error should mention the search was blocked"
5308 );
5309}
5310
5311#[gpui::test]
5312async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5313 init_test(cx);
5314
5315 let fs = FakeFs::new(cx.executor());
5316 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5317 .await;
5318 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5319
5320 cx.update(|cx| {
5321 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5322 settings.tool_permissions.tools.insert(
5323 EditFileTool::NAME.into(),
5324 agent_settings::ToolRules {
5325 default: Some(settings::ToolPermissionMode::Confirm),
5326 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5327 always_deny: vec![],
5328 always_confirm: vec![],
5329 invalid_patterns: vec![],
5330 },
5331 );
5332 agent_settings::AgentSettings::override_global(settings, cx);
5333 });
5334
5335 let context_server_registry =
5336 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5337 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5338 let templates = crate::Templates::new();
5339 let thread = cx.new(|cx| {
5340 crate::Thread::new(
5341 project.clone(),
5342 cx.new(|_cx| prompt_store::ProjectContext::default()),
5343 context_server_registry,
5344 templates.clone(),
5345 None,
5346 cx,
5347 )
5348 });
5349
5350 #[allow(clippy::arc_with_non_send_sync)]
5351 let tool = Arc::new(crate::EditFileTool::new(
5352 project,
5353 thread.downgrade(),
5354 language_registry,
5355 templates,
5356 ));
5357 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5358
5359 let _task = cx.update(|cx| {
5360 tool.run(
5361 crate::EditFileToolInput {
5362 display_description: "Edit README".to_string(),
5363 path: "root/README.md".into(),
5364 mode: crate::EditFileMode::Edit,
5365 },
5366 event_stream,
5367 cx,
5368 )
5369 });
5370
5371 cx.run_until_parked();
5372
5373 let event = rx.try_next();
5374 assert!(
5375 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5376 "expected no authorization request for allowed .md file"
5377 );
5378}
5379
5380#[gpui::test]
5381async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5382 init_test(cx);
5383
5384 let fs = FakeFs::new(cx.executor());
5385 fs.insert_tree(
5386 "/root",
5387 json!({
5388 ".zed": {
5389 "settings.json": "{}"
5390 },
5391 "README.md": "# Hello"
5392 }),
5393 )
5394 .await;
5395 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5396
5397 cx.update(|cx| {
5398 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5399 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5400 agent_settings::AgentSettings::override_global(settings, cx);
5401 });
5402
5403 let context_server_registry =
5404 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5405 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5406 let templates = crate::Templates::new();
5407 let thread = cx.new(|cx| {
5408 crate::Thread::new(
5409 project.clone(),
5410 cx.new(|_cx| prompt_store::ProjectContext::default()),
5411 context_server_registry,
5412 templates.clone(),
5413 None,
5414 cx,
5415 )
5416 });
5417
5418 #[allow(clippy::arc_with_non_send_sync)]
5419 let tool = Arc::new(crate::EditFileTool::new(
5420 project,
5421 thread.downgrade(),
5422 language_registry,
5423 templates,
5424 ));
5425
5426 // Editing a file inside .zed/ should still prompt even with global default: allow,
5427 // because local settings paths are sensitive and require confirmation regardless.
5428 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5429 let _task = cx.update(|cx| {
5430 tool.run(
5431 crate::EditFileToolInput {
5432 display_description: "Edit local settings".to_string(),
5433 path: "root/.zed/settings.json".into(),
5434 mode: crate::EditFileMode::Edit,
5435 },
5436 event_stream,
5437 cx,
5438 )
5439 });
5440
5441 let _update = rx.expect_update_fields().await;
5442 let _auth = rx.expect_authorization().await;
5443}
5444
5445#[gpui::test]
5446async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5447 init_test(cx);
5448
5449 cx.update(|cx| {
5450 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5451 settings.tool_permissions.tools.insert(
5452 FetchTool::NAME.into(),
5453 agent_settings::ToolRules {
5454 default: Some(settings::ToolPermissionMode::Allow),
5455 always_allow: vec![],
5456 always_deny: vec![
5457 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5458 ],
5459 always_confirm: vec![],
5460 invalid_patterns: vec![],
5461 },
5462 );
5463 agent_settings::AgentSettings::override_global(settings, cx);
5464 });
5465
5466 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5467
5468 #[allow(clippy::arc_with_non_send_sync)]
5469 let tool = Arc::new(crate::FetchTool::new(http_client));
5470 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5471
5472 let input: crate::FetchToolInput =
5473 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5474
5475 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5476
5477 let result = task.await;
5478 assert!(result.is_err(), "expected fetch to be blocked");
5479 assert!(
5480 result.unwrap_err().to_string().contains("blocked"),
5481 "error should mention the fetch was blocked"
5482 );
5483}
5484
5485#[gpui::test]
5486async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5487 init_test(cx);
5488
5489 cx.update(|cx| {
5490 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5491 settings.tool_permissions.tools.insert(
5492 FetchTool::NAME.into(),
5493 agent_settings::ToolRules {
5494 default: Some(settings::ToolPermissionMode::Confirm),
5495 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5496 always_deny: vec![],
5497 always_confirm: vec![],
5498 invalid_patterns: vec![],
5499 },
5500 );
5501 agent_settings::AgentSettings::override_global(settings, cx);
5502 });
5503
5504 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5505
5506 #[allow(clippy::arc_with_non_send_sync)]
5507 let tool = Arc::new(crate::FetchTool::new(http_client));
5508 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5509
5510 let input: crate::FetchToolInput =
5511 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5512
5513 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5514
5515 cx.run_until_parked();
5516
5517 let event = rx.try_next();
5518 assert!(
5519 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5520 "expected no authorization request for allowed docs.rs URL"
5521 );
5522}
5523
5524#[gpui::test]
5525async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5526 init_test(cx);
5527 always_allow_tools(cx);
5528
5529 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5530 let fake_model = model.as_fake();
5531
5532 // Add a tool so we can simulate tool calls
5533 thread.update(cx, |thread, _cx| {
5534 thread.add_tool(EchoTool, None);
5535 });
5536
5537 // Start a turn by sending a message
5538 let mut events = thread
5539 .update(cx, |thread, cx| {
5540 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5541 })
5542 .unwrap();
5543 cx.run_until_parked();
5544
5545 // Simulate the model making a tool call
5546 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5547 LanguageModelToolUse {
5548 id: "tool_1".into(),
5549 name: "echo".into(),
5550 raw_input: r#"{"text": "hello"}"#.into(),
5551 input: json!({"text": "hello"}),
5552 is_input_complete: true,
5553 thought_signature: None,
5554 },
5555 ));
5556 fake_model
5557 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5558
5559 // Signal that a message is queued before ending the stream
5560 thread.update(cx, |thread, _cx| {
5561 thread.set_has_queued_message(true);
5562 });
5563
5564 // Now end the stream - tool will run, and the boundary check should see the queue
5565 fake_model.end_last_completion_stream();
5566
5567 // Collect all events until the turn stops
5568 let all_events = collect_events_until_stop(&mut events, cx).await;
5569
5570 // Verify we received the tool call event
5571 let tool_call_ids: Vec<_> = all_events
5572 .iter()
5573 .filter_map(|e| match e {
5574 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5575 _ => None,
5576 })
5577 .collect();
5578 assert_eq!(
5579 tool_call_ids,
5580 vec!["tool_1"],
5581 "Should have received a tool call event for our echo tool"
5582 );
5583
5584 // The turn should have stopped with EndTurn
5585 let stop_reasons = stop_events(all_events);
5586 assert_eq!(
5587 stop_reasons,
5588 vec![acp::StopReason::EndTurn],
5589 "Turn should have ended after tool completion due to queued message"
5590 );
5591
5592 // Verify the queued message flag is still set
5593 thread.update(cx, |thread, _cx| {
5594 assert!(
5595 thread.has_queued_message(),
5596 "Should still have queued message flag set"
5597 );
5598 });
5599
5600 // Thread should be idle now
5601 thread.update(cx, |thread, _cx| {
5602 assert!(
5603 thread.is_turn_complete(),
5604 "Thread should not be running after turn ends"
5605 );
5606 });
5607}