1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, UserMessageId,
4};
5use agent_client_protocol::{self as acp};
6use agent_settings::AgentProfileId;
7use anyhow::Result;
8use client::{Client, UserStore};
9use cloud_llm_client::CompletionIntent;
10use collections::IndexMap;
11use context_server::{ContextServer, ContextServerCommand, ContextServerId};
12use feature_flags::FeatureFlagAppExt as _;
13use fs::{FakeFs, Fs};
14use futures::{
15 FutureExt as _, StreamExt,
16 channel::{
17 mpsc::{self, UnboundedReceiver},
18 oneshot,
19 },
20 future::{Fuse, Shared},
21};
22use gpui::{
23 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
24 http_client::FakeHttpClient,
25};
26use indoc::indoc;
27use language_model::{
28 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
29 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
30 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
31 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
32};
33use pretty_assertions::assert_eq;
34use project::{
35 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
36};
37use prompt_store::ProjectContext;
38use reqwest_client::ReqwestClient;
39use schemars::JsonSchema;
40use serde::{Deserialize, Serialize};
41use serde_json::json;
42use settings::{Settings, SettingsStore};
43use std::{
44 path::Path,
45 pin::Pin,
46 rc::Rc,
47 sync::{
48 Arc,
49 atomic::{AtomicBool, Ordering},
50 },
51 time::Duration,
52};
53use util::path;
54
55mod edit_file_thread_test;
56mod test_tools;
57use test_tools::*;
58
59fn init_test(cx: &mut TestAppContext) {
60 cx.update(|cx| {
61 let settings_store = SettingsStore::test(cx);
62 cx.set_global(settings_store);
63 });
64}
65
/// Test double for a terminal handle; implements `crate::TerminalHandle` below.
struct FakeTerminalHandle {
    // Set to `true` by `kill()`; read back through `was_killed()`.
    killed: Arc<AtomicBool>,
    // Value reported by `was_stopped_by_user()`; written via `set_stopped_by_user()`.
    stopped_by_user: Arc<AtomicBool>,
    // One-shot sender that `signal_exit()` takes and fires at most once.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    // Shared task handed out (cloned) by `wait_for_exit()`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned (cloned) by `current_output()`.
    output: acp::TerminalOutputResponse,
    // Identifier returned (cloned) by `id()`.
    id: acp::TerminalId,
}
74
75impl FakeTerminalHandle {
76 fn new_never_exits(cx: &mut App) -> Self {
77 let killed = Arc::new(AtomicBool::new(false));
78 let stopped_by_user = Arc::new(AtomicBool::new(false));
79
80 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
81
82 let wait_for_exit = cx
83 .spawn(async move |_cx| {
84 // Wait for the exit signal (sent when kill() is called)
85 let _ = exit_receiver.await;
86 acp::TerminalExitStatus::new()
87 })
88 .shared();
89
90 Self {
91 killed,
92 stopped_by_user,
93 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
94 wait_for_exit,
95 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
96 id: acp::TerminalId::new("fake_terminal".to_string()),
97 }
98 }
99
100 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
101 let killed = Arc::new(AtomicBool::new(false));
102 let stopped_by_user = Arc::new(AtomicBool::new(false));
103 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
104
105 let wait_for_exit = cx
106 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
107 .shared();
108
109 Self {
110 killed,
111 stopped_by_user,
112 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
113 wait_for_exit,
114 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
115 id: acp::TerminalId::new("fake_terminal".to_string()),
116 }
117 }
118
119 fn was_killed(&self) -> bool {
120 self.killed.load(Ordering::SeqCst)
121 }
122
123 fn set_stopped_by_user(&self, stopped: bool) {
124 self.stopped_by_user.store(stopped, Ordering::SeqCst);
125 }
126
127 fn signal_exit(&self) {
128 if let Some(sender) = self.exit_sender.borrow_mut().take() {
129 let _ = sender.send(());
130 }
131 }
132}
133
134impl crate::TerminalHandle for FakeTerminalHandle {
135 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
136 Ok(self.id.clone())
137 }
138
139 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
140 Ok(self.output.clone())
141 }
142
143 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
144 Ok(self.wait_for_exit.clone())
145 }
146
147 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
148 self.killed.store(true, Ordering::SeqCst);
149 self.signal_exit();
150 Ok(())
151 }
152
153 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
154 Ok(self.stopped_by_user.load(Ordering::SeqCst))
155 }
156}
157
/// Test double for a subagent handle; implements `SubagentHandle` below.
struct FakeSubagentHandle {
    // Session id returned (cloned) by `id()`.
    session_id: acp::SessionId,
    // Shared task awaited by `wait_for_summary()` to produce the summary text.
    wait_for_summary_task: Shared<Task<String>>,
}
162
163impl FakeSubagentHandle {
164 fn new_never_completes(cx: &App) -> Self {
165 Self {
166 session_id: acp::SessionId::new("subagent-id"),
167 wait_for_summary_task: cx.background_spawn(std::future::pending()).shared(),
168 }
169 }
170}
171
172impl SubagentHandle for FakeSubagentHandle {
173 fn id(&self) -> acp::SessionId {
174 self.session_id.clone()
175 }
176
177 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
178 let task = self.wait_for_summary_task.clone();
179 cx.background_spawn(async move { Ok(task.await) })
180 }
181}
182
/// Test `ThreadEnvironment` that hands out pre-configured fake handles.
/// Both slots start as `None` via `Default`.
#[derive(Default)]
struct FakeThreadEnvironment {
    // Handle returned by `create_terminal`; that method panics if this is `None`.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    // Handle returned by `create_subagent`; that method panics if this is `None`.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
188
189impl FakeThreadEnvironment {
190 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
191 Self {
192 terminal_handle: Some(terminal_handle.into()),
193 ..self
194 }
195 }
196
197 pub fn with_subagent(self, subagent_handle: FakeSubagentHandle) -> Self {
198 Self {
199 subagent_handle: Some(subagent_handle.into()),
200 ..self
201 }
202 }
203}
204
205impl crate::ThreadEnvironment for FakeThreadEnvironment {
206 fn create_terminal(
207 &self,
208 _command: String,
209 _cwd: Option<std::path::PathBuf>,
210 _output_byte_limit: Option<u64>,
211 _cx: &mut AsyncApp,
212 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
213 let handle = self
214 .terminal_handle
215 .clone()
216 .expect("Terminal handle not available on FakeThreadEnvironment");
217 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
218 }
219
220 fn create_subagent(
221 &self,
222 _parent_thread: Entity<Thread>,
223 _label: String,
224 _initial_prompt: String,
225 _timeout_ms: Option<Duration>,
226 _allowed_tools: Option<Vec<String>>,
227 _cx: &mut App,
228 ) -> Result<Rc<dyn SubagentHandle>> {
229 Ok(self
230 .subagent_handle
231 .clone()
232 .expect("Subagent handle not available on FakeThreadEnvironment")
233 as Rc<dyn SubagentHandle>)
234 }
235}
236
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    // Every handle created by `create_terminal`, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
241
242impl MultiTerminalEnvironment {
243 fn new() -> Self {
244 Self {
245 handles: std::cell::RefCell::new(Vec::new()),
246 }
247 }
248
249 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
250 self.handles.borrow().clone()
251 }
252}
253
254impl crate::ThreadEnvironment for MultiTerminalEnvironment {
255 fn create_terminal(
256 &self,
257 _command: String,
258 _cwd: Option<std::path::PathBuf>,
259 _output_byte_limit: Option<u64>,
260 cx: &mut AsyncApp,
261 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
262 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
263 self.handles.borrow_mut().push(handle.clone());
264 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
265 }
266
267 fn create_subagent(
268 &self,
269 _parent_thread: Entity<Thread>,
270 _label: String,
271 _initial_prompt: String,
272 _timeout: Option<Duration>,
273 _allowed_tools: Option<Vec<String>>,
274 _cx: &mut App,
275 ) -> Result<Rc<dyn SubagentHandle>> {
276 unimplemented!()
277 }
278}
279
280fn always_allow_tools(cx: &mut TestAppContext) {
281 cx.update(|cx| {
282 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
283 settings.always_allow_tool_actions = true;
284 agent_settings::AgentSettings::override_global(settings, cx);
285 });
286}
287
288#[gpui::test]
289async fn test_echo(cx: &mut TestAppContext) {
290 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
291 let fake_model = model.as_fake();
292
293 let events = thread
294 .update(cx, |thread, cx| {
295 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
296 })
297 .unwrap();
298 cx.run_until_parked();
299 fake_model.send_last_completion_stream_text_chunk("Hello");
300 fake_model
301 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
302 fake_model.end_last_completion_stream();
303
304 let events = events.collect().await;
305 thread.update(cx, |thread, _cx| {
306 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
307 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
308 });
309 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
310}
311
/// Runs the terminal tool with `timeout_ms: Some(5)` against a fake terminal
/// that never exits on its own, and asserts that the handle gets killed and
/// that the tool result still contains the fake's "partial output".
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so the kill flag can be inspected after the tool runs.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The first update is expected to carry the terminal content block.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fused and pinned so the task can be polled repeatedly with `now_or_never`.
    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Wall-clock safety net so a task that never completes fails the test instead of hanging.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        // Let pending work run, then wait on a short executor timer before polling again.
        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
377
/// Runs the terminal tool with `timeout_ms: None` against a fake terminal that
/// never exits, waits on a 25ms executor timer, and asserts the handle was not
/// killed. Currently marked `#[ignore]`.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so the kill flag can be inspected afterwards.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // The task is held (not awaited) for the rest of the test.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    // The first update is expected to carry the terminal content block.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give any (unexpected) timeout logic a chance to fire before checking.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
427
428#[gpui::test]
429async fn test_thinking(cx: &mut TestAppContext) {
430 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
431 let fake_model = model.as_fake();
432
433 let events = thread
434 .update(cx, |thread, cx| {
435 thread.send(
436 UserMessageId::new(),
437 [indoc! {"
438 Testing:
439
440 Generate a thinking step where you just think the word 'Think',
441 and have your final answer be 'Hello'
442 "}],
443 cx,
444 )
445 })
446 .unwrap();
447 cx.run_until_parked();
448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
449 text: "Think".to_string(),
450 signature: None,
451 });
452 fake_model.send_last_completion_stream_text_chunk("Hello");
453 fake_model
454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
455 fake_model.end_last_completion_stream();
456
457 let events = events.collect().await;
458 thread.update(cx, |thread, _cx| {
459 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
460 assert_eq!(
461 thread.last_message().unwrap().to_markdown(),
462 indoc! {"
463 <think>Think</think>
464 Hello
465 "}
466 )
467 });
468 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
469}
470
471#[gpui::test]
472async fn test_system_prompt(cx: &mut TestAppContext) {
473 let ThreadTest {
474 model,
475 thread,
476 project_context,
477 ..
478 } = setup(cx, TestModel::Fake).await;
479 let fake_model = model.as_fake();
480
481 project_context.update(cx, |project_context, _cx| {
482 project_context.shell = "test-shell".into()
483 });
484 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
485 thread
486 .update(cx, |thread, cx| {
487 thread.send(UserMessageId::new(), ["abc"], cx)
488 })
489 .unwrap();
490 cx.run_until_parked();
491 let mut pending_completions = fake_model.pending_completions();
492 assert_eq!(
493 pending_completions.len(),
494 1,
495 "unexpected pending completions: {:?}",
496 pending_completions
497 );
498
499 let pending_completion = pending_completions.pop().unwrap();
500 assert_eq!(pending_completion.messages[0].role, Role::System);
501
502 let system_message = &pending_completion.messages[0];
503 let system_prompt = system_message.content[0].to_str().unwrap();
504 assert!(
505 system_prompt.contains("test-shell"),
506 "unexpected system message: {:?}",
507 system_message
508 );
509 assert!(
510 system_prompt.contains("## Fixing Diagnostics"),
511 "unexpected system message: {:?}",
512 system_message
513 );
514}
515
516#[gpui::test]
517async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
518 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
519 let fake_model = model.as_fake();
520
521 thread
522 .update(cx, |thread, cx| {
523 thread.send(UserMessageId::new(), ["abc"], cx)
524 })
525 .unwrap();
526 cx.run_until_parked();
527 let mut pending_completions = fake_model.pending_completions();
528 assert_eq!(
529 pending_completions.len(),
530 1,
531 "unexpected pending completions: {:?}",
532 pending_completions
533 );
534
535 let pending_completion = pending_completions.pop().unwrap();
536 assert_eq!(pending_completion.messages[0].role, Role::System);
537
538 let system_message = &pending_completion.messages[0];
539 let system_prompt = system_message.content[0].to_str().unwrap();
540 assert!(
541 !system_prompt.contains("## Tool Use"),
542 "unexpected system message: {:?}",
543 system_message
544 );
545 assert!(
546 !system_prompt.contains("## Fixing Diagnostics"),
547 "unexpected system message: {:?}",
548 system_message
549 );
550}
551
/// Walks a thread through three turns (two plain messages and one tool call)
/// and, after each, asserts that only the most recent user-role message in the
/// outgoing request has `cache: true`. The system message at index 0 is
/// excluded from every comparison via `messages[1..]`.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The pending completion inspected here is the follow-up request that
    // carries the tool result back to the model.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
697
/// End-to-end test (ignored unless the `e2e` feature is enabled) that set up
/// with `TestModel::Sonnet4` and exercises two tool calls: one with the echo
/// tool and one with the delay tool. Each turn must stop with `EndTurn`, and
/// the final agent message must contain the text "Ding".
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            // Swap the echo tool for the delay tool before the second turn.
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool, None);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // On failure the whole thread is printed as markdown for diagnosis.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
757
/// End-to-end test (ignored unless the `e2e` feature is enabled) that watches
/// `ToolCall` events for the `word_list` tool and asserts that at least one of
/// them was observed while the tool input was still incomplete.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool, None);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // A pending call whose input is incomplete and lacks the "g" key
                        // counts as a partially streamed tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Any other status must come with both the "a" and "g" keys present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
809
/// Exercises the permission flow for a tool that requires authorization:
/// one call allowed, one denied, one answered with "always allow", and a final
/// call whose result is asserted without responding to any authorization request.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission, None);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Stream two tool uses in the same completion so both authorizations are pending at once.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request's last message should carry both tool results, in call order.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // No authorization response is sent here; the result must appear on its own.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
948
949#[gpui::test]
950async fn test_tool_hallucination(cx: &mut TestAppContext) {
951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
952 let fake_model = model.as_fake();
953
954 let mut events = thread
955 .update(cx, |thread, cx| {
956 thread.send(UserMessageId::new(), ["abc"], cx)
957 })
958 .unwrap();
959 cx.run_until_parked();
960 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
961 LanguageModelToolUse {
962 id: "tool_id_1".into(),
963 name: "nonexistent_tool".into(),
964 raw_input: "{}".into(),
965 input: json!({}),
966 is_input_complete: true,
967 thought_signature: None,
968 },
969 ));
970 fake_model.end_last_completion_stream();
971
972 let tool_call = expect_tool_call(&mut events).await;
973 assert_eq!(tool_call.title, "nonexistent_tool");
974 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
975 let update = expect_tool_call_update_fields(&mut events).await;
976 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
977}
978
979async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
980 let event = events
981 .next()
982 .await
983 .expect("no tool call authorization event received")
984 .unwrap();
985 match event {
986 ThreadEvent::ToolCall(tool_call) => tool_call,
987 event => {
988 panic!("Unexpected event {event:?}");
989 }
990 }
991}
992
993async fn expect_tool_call_update_fields(
994 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
995) -> acp::ToolCallUpdate {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 match event {
1002 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1003 event => {
1004 panic!("Unexpected event {event:?}");
1005 }
1006 }
1007}
1008
1009async fn next_tool_call_authorization(
1010 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1011) -> ToolCallAuthorization {
1012 loop {
1013 let event = events
1014 .next()
1015 .await
1016 .expect("no tool call authorization event received")
1017 .unwrap();
1018 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1019 let permission_kinds = tool_call_authorization
1020 .options
1021 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1022 .map(|option| option.kind);
1023 let allow_once = tool_call_authorization
1024 .options
1025 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1026 .map(|option| option.kind);
1027
1028 assert_eq!(
1029 permission_kinds,
1030 Some(acp::PermissionOptionKind::AllowAlways)
1031 );
1032 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1033 return tool_call_authorization;
1034 }
1035 }
1036}
1037
/// A terminal command with a recognizable program name yields three dropdown
/// choices: tool-wide, per-command pattern, and one-time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, "cargo build --release");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo` commands"));
    assert!(labels.contains(&"Only this time"));
}
1057
/// Editing a file under `src/` offers both a tool-wide choice and a choice
/// scoped to the `src/` directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new(EditFileTool::NAME, "src/main.rs");
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
}
1074
/// Fetching a `docs.rs` URL must offer both a tool-wide choice and a choice
/// scoped to the `docs.rs` domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let context = ToolPermissionContext::new(FetchTool::NAME, "https://docs.rs/gpui");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    let has_label = |wanted: &str| labels.iter().any(|label| *label == wanted);

    assert!(has_label("Always for fetch"));
    assert!(has_label("Always for `docs.rs`"));
}
1091
/// For the `./deploy.sh --production` input only two choices are expected
/// (tool-wide and one-off), and none of them may mention "commands".
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, "./deploy.sh --production");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 2);

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    let has_label = |wanted: &str| labels.iter().any(|label| *label == wanted);

    assert!(has_label("Always for terminal"));
    assert!(has_label("Only this time"));
    // No per-command pattern entry may be present.
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1111
/// Checks the option ids generated for a terminal command: the tool-wide,
/// one-off, and pattern-scoped ids must exist on both the allow and deny side.
#[test]
fn test_permission_option_ids_for_terminal() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, "cargo build --release");
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Gather both id lists in a single pass over the choices.
    let mut allow_ids: Vec<String> = Vec::new();
    let mut deny_ids: Vec<String> = Vec::new();
    for choice in choices.iter() {
        allow_ids.push(choice.allow.option_id.0.to_string());
        deny_ids.push(choice.deny.option_id.0.to_string());
    }

    let has_id = |ids: &[String], wanted: &str| ids.iter().any(|id| id.as_str() == wanted);

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    let has_allow_pattern = allow_ids
        .iter()
        .any(|id| id.starts_with("always_allow_pattern:terminal\n"));
    assert!(has_allow_pattern, "Missing allow pattern option");

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    let has_deny_pattern = deny_ids
        .iter()
        .any(|id| id.starts_with("always_deny_pattern:terminal\n"));
    assert!(has_deny_pattern, "Missing deny pattern option");
}
1149
1150#[gpui::test]
1151#[cfg_attr(not(feature = "e2e"), ignore)]
1152async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1153 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1154
1155 // Test concurrent tool calls with different delay times
1156 let events = thread
1157 .update(cx, |thread, cx| {
1158 thread.add_tool(DelayTool, None);
1159 thread.send(
1160 UserMessageId::new(),
1161 [
1162 "Call the delay tool twice in the same message.",
1163 "Once with 100ms. Once with 300ms.",
1164 "When both timers are complete, describe the outputs.",
1165 ],
1166 cx,
1167 )
1168 })
1169 .unwrap()
1170 .collect()
1171 .await;
1172
1173 let stop_reasons = stop_events(events);
1174 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1175
1176 thread.update(cx, |thread, _cx| {
1177 let last_message = thread.last_message().unwrap();
1178 let agent_message = last_message.as_agent_message().unwrap();
1179 let text = agent_message
1180 .content
1181 .iter()
1182 .filter_map(|content| {
1183 if let AgentMessageContent::Text(text) = content {
1184 Some(text.as_str())
1185 } else {
1186 None
1187 }
1188 })
1189 .collect::<String>();
1190
1191 assert!(text.contains("Ding"));
1192 });
1193}
1194
1195#[gpui::test]
1196async fn test_profiles(cx: &mut TestAppContext) {
1197 let ThreadTest {
1198 model, thread, fs, ..
1199 } = setup(cx, TestModel::Fake).await;
1200 let fake_model = model.as_fake();
1201
1202 thread.update(cx, |thread, _cx| {
1203 thread.add_tool(DelayTool, None);
1204 thread.add_tool(EchoTool, None);
1205 thread.add_tool(InfiniteTool, None);
1206 });
1207
1208 // Override profiles and wait for settings to be loaded.
1209 fs.insert_file(
1210 paths::settings_file(),
1211 json!({
1212 "agent": {
1213 "profiles": {
1214 "test-1": {
1215 "name": "Test Profile 1",
1216 "tools": {
1217 EchoTool::NAME: true,
1218 DelayTool::NAME: true,
1219 }
1220 },
1221 "test-2": {
1222 "name": "Test Profile 2",
1223 "tools": {
1224 InfiniteTool::NAME: true,
1225 }
1226 }
1227 }
1228 }
1229 })
1230 .to_string()
1231 .into_bytes(),
1232 )
1233 .await;
1234 cx.run_until_parked();
1235
1236 // Test that test-1 profile (default) has echo and delay tools
1237 thread
1238 .update(cx, |thread, cx| {
1239 thread.set_profile(AgentProfileId("test-1".into()), cx);
1240 thread.send(UserMessageId::new(), ["test"], cx)
1241 })
1242 .unwrap();
1243 cx.run_until_parked();
1244
1245 let mut pending_completions = fake_model.pending_completions();
1246 assert_eq!(pending_completions.len(), 1);
1247 let completion = pending_completions.pop().unwrap();
1248 let tool_names: Vec<String> = completion
1249 .tools
1250 .iter()
1251 .map(|tool| tool.name.clone())
1252 .collect();
1253 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1254 fake_model.end_last_completion_stream();
1255
1256 // Switch to test-2 profile, and verify that it has only the infinite tool.
1257 thread
1258 .update(cx, |thread, cx| {
1259 thread.set_profile(AgentProfileId("test-2".into()), cx);
1260 thread.send(UserMessageId::new(), ["test2"], cx)
1261 })
1262 .unwrap();
1263 cx.run_until_parked();
1264 let mut pending_completions = fake_model.pending_completions();
1265 assert_eq!(pending_completions.len(), 1);
1266 let completion = pending_completions.pop().unwrap();
1267 let tool_names: Vec<String> = completion
1268 .tools
1269 .iter()
1270 .map(|tool| tool.name.clone())
1271 .collect();
1272 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1273}
1274
1275#[gpui::test]
1276async fn test_mcp_tools(cx: &mut TestAppContext) {
1277 let ThreadTest {
1278 model,
1279 thread,
1280 context_server_store,
1281 fs,
1282 ..
1283 } = setup(cx, TestModel::Fake).await;
1284 let fake_model = model.as_fake();
1285
1286 // Override profiles and wait for settings to be loaded.
1287 fs.insert_file(
1288 paths::settings_file(),
1289 json!({
1290 "agent": {
1291 "always_allow_tool_actions": true,
1292 "profiles": {
1293 "test": {
1294 "name": "Test Profile",
1295 "enable_all_context_servers": true,
1296 "tools": {
1297 EchoTool::NAME: true,
1298 }
1299 },
1300 }
1301 }
1302 })
1303 .to_string()
1304 .into_bytes(),
1305 )
1306 .await;
1307 cx.run_until_parked();
1308 thread.update(cx, |thread, cx| {
1309 thread.set_profile(AgentProfileId("test".into()), cx)
1310 });
1311
1312 let mut mcp_tool_calls = setup_context_server(
1313 "test_server",
1314 vec![context_server::types::Tool {
1315 name: "echo".into(),
1316 description: None,
1317 input_schema: serde_json::to_value(EchoTool::input_schema(
1318 LanguageModelToolSchemaFormat::JsonSchema,
1319 ))
1320 .unwrap(),
1321 output_schema: None,
1322 annotations: None,
1323 }],
1324 &context_server_store,
1325 cx,
1326 );
1327
1328 let events = thread.update(cx, |thread, cx| {
1329 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1330 });
1331 cx.run_until_parked();
1332
1333 // Simulate the model calling the MCP tool.
1334 let completion = fake_model.pending_completions().pop().unwrap();
1335 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1336 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1337 LanguageModelToolUse {
1338 id: "tool_1".into(),
1339 name: "echo".into(),
1340 raw_input: json!({"text": "test"}).to_string(),
1341 input: json!({"text": "test"}),
1342 is_input_complete: true,
1343 thought_signature: None,
1344 },
1345 ));
1346 fake_model.end_last_completion_stream();
1347 cx.run_until_parked();
1348
1349 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1350 assert_eq!(tool_call_params.name, "echo");
1351 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1352 tool_call_response
1353 .send(context_server::types::CallToolResponse {
1354 content: vec![context_server::types::ToolResponseContent::Text {
1355 text: "test".into(),
1356 }],
1357 is_error: None,
1358 meta: None,
1359 structured_content: None,
1360 })
1361 .unwrap();
1362 cx.run_until_parked();
1363
1364 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1365 fake_model.send_last_completion_stream_text_chunk("Done!");
1366 fake_model.end_last_completion_stream();
1367 events.collect::<Vec<_>>().await;
1368
1369 // Send again after adding the echo tool, ensuring the name collision is resolved.
1370 let events = thread.update(cx, |thread, cx| {
1371 thread.add_tool(EchoTool, None);
1372 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1373 });
1374 cx.run_until_parked();
1375 let completion = fake_model.pending_completions().pop().unwrap();
1376 assert_eq!(
1377 tool_names_for_completion(&completion),
1378 vec!["echo", "test_server_echo"]
1379 );
1380 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1381 LanguageModelToolUse {
1382 id: "tool_2".into(),
1383 name: "test_server_echo".into(),
1384 raw_input: json!({"text": "mcp"}).to_string(),
1385 input: json!({"text": "mcp"}),
1386 is_input_complete: true,
1387 thought_signature: None,
1388 },
1389 ));
1390 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1391 LanguageModelToolUse {
1392 id: "tool_3".into(),
1393 name: "echo".into(),
1394 raw_input: json!({"text": "native"}).to_string(),
1395 input: json!({"text": "native"}),
1396 is_input_complete: true,
1397 thought_signature: None,
1398 },
1399 ));
1400 fake_model.end_last_completion_stream();
1401 cx.run_until_parked();
1402
1403 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1404 assert_eq!(tool_call_params.name, "echo");
1405 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1406 tool_call_response
1407 .send(context_server::types::CallToolResponse {
1408 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1409 is_error: None,
1410 meta: None,
1411 structured_content: None,
1412 })
1413 .unwrap();
1414 cx.run_until_parked();
1415
1416 // Ensure the tool results were inserted with the correct names.
1417 let completion = fake_model.pending_completions().pop().unwrap();
1418 assert_eq!(
1419 completion.messages.last().unwrap().content,
1420 vec![
1421 MessageContent::ToolResult(LanguageModelToolResult {
1422 tool_use_id: "tool_3".into(),
1423 tool_name: "echo".into(),
1424 is_error: false,
1425 content: "native".into(),
1426 output: Some("native".into()),
1427 },),
1428 MessageContent::ToolResult(LanguageModelToolResult {
1429 tool_use_id: "tool_2".into(),
1430 tool_name: "test_server_echo".into(),
1431 is_error: false,
1432 content: "mcp".into(),
1433 output: Some("mcp".into()),
1434 },),
1435 ]
1436 );
1437 fake_model.end_last_completion_stream();
1438 events.collect::<Vec<_>>().await;
1439}
1440
1441#[gpui::test]
1442async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1443 let ThreadTest {
1444 model,
1445 thread,
1446 context_server_store,
1447 fs,
1448 ..
1449 } = setup(cx, TestModel::Fake).await;
1450 let fake_model = model.as_fake();
1451
1452 // Set up a profile with all tools enabled
1453 fs.insert_file(
1454 paths::settings_file(),
1455 json!({
1456 "agent": {
1457 "profiles": {
1458 "test": {
1459 "name": "Test Profile",
1460 "enable_all_context_servers": true,
1461 "tools": {
1462 EchoTool::NAME: true,
1463 DelayTool::NAME: true,
1464 WordListTool::NAME: true,
1465 ToolRequiringPermission::NAME: true,
1466 InfiniteTool::NAME: true,
1467 }
1468 },
1469 }
1470 }
1471 })
1472 .to_string()
1473 .into_bytes(),
1474 )
1475 .await;
1476 cx.run_until_parked();
1477
1478 thread.update(cx, |thread, cx| {
1479 thread.set_profile(AgentProfileId("test".into()), cx);
1480 thread.add_tool(EchoTool, None);
1481 thread.add_tool(DelayTool, None);
1482 thread.add_tool(WordListTool, None);
1483 thread.add_tool(ToolRequiringPermission, None);
1484 thread.add_tool(InfiniteTool, None);
1485 });
1486
1487 // Set up multiple context servers with some overlapping tool names
1488 let _server1_calls = setup_context_server(
1489 "xxx",
1490 vec![
1491 context_server::types::Tool {
1492 name: "echo".into(), // Conflicts with native EchoTool
1493 description: None,
1494 input_schema: serde_json::to_value(EchoTool::input_schema(
1495 LanguageModelToolSchemaFormat::JsonSchema,
1496 ))
1497 .unwrap(),
1498 output_schema: None,
1499 annotations: None,
1500 },
1501 context_server::types::Tool {
1502 name: "unique_tool_1".into(),
1503 description: None,
1504 input_schema: json!({"type": "object", "properties": {}}),
1505 output_schema: None,
1506 annotations: None,
1507 },
1508 ],
1509 &context_server_store,
1510 cx,
1511 );
1512
1513 let _server2_calls = setup_context_server(
1514 "yyy",
1515 vec![
1516 context_server::types::Tool {
1517 name: "echo".into(), // Also conflicts with native EchoTool
1518 description: None,
1519 input_schema: serde_json::to_value(EchoTool::input_schema(
1520 LanguageModelToolSchemaFormat::JsonSchema,
1521 ))
1522 .unwrap(),
1523 output_schema: None,
1524 annotations: None,
1525 },
1526 context_server::types::Tool {
1527 name: "unique_tool_2".into(),
1528 description: None,
1529 input_schema: json!({"type": "object", "properties": {}}),
1530 output_schema: None,
1531 annotations: None,
1532 },
1533 context_server::types::Tool {
1534 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1535 description: None,
1536 input_schema: json!({"type": "object", "properties": {}}),
1537 output_schema: None,
1538 annotations: None,
1539 },
1540 context_server::types::Tool {
1541 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1542 description: None,
1543 input_schema: json!({"type": "object", "properties": {}}),
1544 output_schema: None,
1545 annotations: None,
1546 },
1547 ],
1548 &context_server_store,
1549 cx,
1550 );
1551 let _server3_calls = setup_context_server(
1552 "zzz",
1553 vec![
1554 context_server::types::Tool {
1555 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1556 description: None,
1557 input_schema: json!({"type": "object", "properties": {}}),
1558 output_schema: None,
1559 annotations: None,
1560 },
1561 context_server::types::Tool {
1562 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1563 description: None,
1564 input_schema: json!({"type": "object", "properties": {}}),
1565 output_schema: None,
1566 annotations: None,
1567 },
1568 context_server::types::Tool {
1569 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1570 description: None,
1571 input_schema: json!({"type": "object", "properties": {}}),
1572 output_schema: None,
1573 annotations: None,
1574 },
1575 ],
1576 &context_server_store,
1577 cx,
1578 );
1579
1580 thread
1581 .update(cx, |thread, cx| {
1582 thread.send(UserMessageId::new(), ["Go"], cx)
1583 })
1584 .unwrap();
1585 cx.run_until_parked();
1586 let completion = fake_model.pending_completions().pop().unwrap();
1587 assert_eq!(
1588 tool_names_for_completion(&completion),
1589 vec![
1590 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1591 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1592 "delay",
1593 "echo",
1594 "infinite",
1595 "tool_requiring_permission",
1596 "unique_tool_1",
1597 "unique_tool_2",
1598 "word_list",
1599 "xxx_echo",
1600 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1601 "yyy_echo",
1602 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1603 ]
1604 );
1605}
1606
1607#[gpui::test]
1608#[cfg_attr(not(feature = "e2e"), ignore)]
1609async fn test_cancellation(cx: &mut TestAppContext) {
1610 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1611
1612 let mut events = thread
1613 .update(cx, |thread, cx| {
1614 thread.add_tool(InfiniteTool, None);
1615 thread.add_tool(EchoTool, None);
1616 thread.send(
1617 UserMessageId::new(),
1618 ["Call the echo tool, then call the infinite tool, then explain their output"],
1619 cx,
1620 )
1621 })
1622 .unwrap();
1623
1624 // Wait until both tools are called.
1625 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1626 let mut echo_id = None;
1627 let mut echo_completed = false;
1628 while let Some(event) = events.next().await {
1629 match event.unwrap() {
1630 ThreadEvent::ToolCall(tool_call) => {
1631 assert_eq!(tool_call.title, expected_tools.remove(0));
1632 if tool_call.title == "Echo" {
1633 echo_id = Some(tool_call.tool_call_id);
1634 }
1635 }
1636 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1637 acp::ToolCallUpdate {
1638 tool_call_id,
1639 fields:
1640 acp::ToolCallUpdateFields {
1641 status: Some(acp::ToolCallStatus::Completed),
1642 ..
1643 },
1644 ..
1645 },
1646 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1647 echo_completed = true;
1648 }
1649 _ => {}
1650 }
1651
1652 if expected_tools.is_empty() && echo_completed {
1653 break;
1654 }
1655 }
1656
1657 // Cancel the current send and ensure that the event stream is closed, even
1658 // if one of the tools is still running.
1659 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1660 let events = events.collect::<Vec<_>>().await;
1661 let last_event = events.last();
1662 assert!(
1663 matches!(
1664 last_event,
1665 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1666 ),
1667 "unexpected event {last_event:?}"
1668 );
1669
1670 // Ensure we can still send a new message after cancellation.
1671 let events = thread
1672 .update(cx, |thread, cx| {
1673 thread.send(
1674 UserMessageId::new(),
1675 ["Testing: reply with 'Hello' then stop."],
1676 cx,
1677 )
1678 })
1679 .unwrap()
1680 .collect::<Vec<_>>()
1681 .await;
1682 thread.update(cx, |thread, _cx| {
1683 let message = thread.last_message().unwrap();
1684 let agent_message = message.as_agent_message().unwrap();
1685 assert_eq!(
1686 agent_message.content,
1687 vec![AgentMessageContent::Text("Hello".to_string())]
1688 );
1689 });
1690 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1691}
1692
1693#[gpui::test]
1694async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1695 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1696 always_allow_tools(cx);
1697 let fake_model = model.as_fake();
1698
1699 let environment = Rc::new(cx.update(|cx| {
1700 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1701 }));
1702 let handle = environment.terminal_handle.clone().unwrap();
1703
1704 let mut events = thread
1705 .update(cx, |thread, cx| {
1706 thread.add_tool(
1707 crate::TerminalTool::new(thread.project().clone(), environment),
1708 None,
1709 );
1710 thread.send(UserMessageId::new(), ["run a command"], cx)
1711 })
1712 .unwrap();
1713
1714 cx.run_until_parked();
1715
1716 // Simulate the model calling the terminal tool
1717 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1718 LanguageModelToolUse {
1719 id: "terminal_tool_1".into(),
1720 name: TerminalTool::NAME.into(),
1721 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1722 input: json!({"command": "sleep 1000", "cd": "."}),
1723 is_input_complete: true,
1724 thought_signature: None,
1725 },
1726 ));
1727 fake_model.end_last_completion_stream();
1728
1729 // Wait for the terminal tool to start running
1730 wait_for_terminal_tool_started(&mut events, cx).await;
1731
1732 // Cancel the thread while the terminal is running
1733 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1734
1735 // Collect remaining events, driving the executor to let cancellation complete
1736 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1737
1738 // Verify the terminal was killed
1739 assert!(
1740 handle.was_killed(),
1741 "expected terminal handle to be killed on cancellation"
1742 );
1743
1744 // Verify we got a cancellation stop event
1745 assert_eq!(
1746 stop_events(remaining_events),
1747 vec![acp::StopReason::Cancelled],
1748 );
1749
1750 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1751 thread.update(cx, |thread, _cx| {
1752 let message = thread.last_message().unwrap();
1753 let agent_message = message.as_agent_message().unwrap();
1754
1755 let tool_use = agent_message
1756 .content
1757 .iter()
1758 .find_map(|content| match content {
1759 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1760 _ => None,
1761 })
1762 .expect("expected tool use in agent message");
1763
1764 let tool_result = agent_message
1765 .tool_results
1766 .get(&tool_use.id)
1767 .expect("expected tool result");
1768
1769 let result_text = match &tool_result.content {
1770 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1771 _ => panic!("expected text content in tool result"),
1772 };
1773
1774 // "partial output" comes from FakeTerminalHandle's output field
1775 assert!(
1776 result_text.contains("partial output"),
1777 "expected tool result to contain terminal output, got: {result_text}"
1778 );
1779 // Match the actual format from process_content in terminal_tool.rs
1780 assert!(
1781 result_text.contains("The user stopped this command"),
1782 "expected tool result to indicate user stopped, got: {result_text}"
1783 );
1784 });
1785
1786 // Verify we can send a new message after cancellation
1787 verify_thread_recovery(&thread, &fake_model, cx).await;
1788}
1789
1790#[gpui::test]
1791async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1792 // This test verifies that tools which properly handle cancellation via
1793 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1794 // to cancellation and report that they were cancelled.
1795 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1796 always_allow_tools(cx);
1797 let fake_model = model.as_fake();
1798
1799 let (tool, was_cancelled) = CancellationAwareTool::new();
1800
1801 let mut events = thread
1802 .update(cx, |thread, cx| {
1803 thread.add_tool(tool, None);
1804 thread.send(
1805 UserMessageId::new(),
1806 ["call the cancellation aware tool"],
1807 cx,
1808 )
1809 })
1810 .unwrap();
1811
1812 cx.run_until_parked();
1813
1814 // Simulate the model calling the cancellation-aware tool
1815 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1816 LanguageModelToolUse {
1817 id: "cancellation_aware_1".into(),
1818 name: "cancellation_aware".into(),
1819 raw_input: r#"{}"#.into(),
1820 input: json!({}),
1821 is_input_complete: true,
1822 thought_signature: None,
1823 },
1824 ));
1825 fake_model.end_last_completion_stream();
1826
1827 cx.run_until_parked();
1828
1829 // Wait for the tool call to be reported
1830 let mut tool_started = false;
1831 let deadline = cx.executor().num_cpus() * 100;
1832 for _ in 0..deadline {
1833 cx.run_until_parked();
1834
1835 while let Some(Some(event)) = events.next().now_or_never() {
1836 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1837 if tool_call.title == "Cancellation Aware Tool" {
1838 tool_started = true;
1839 break;
1840 }
1841 }
1842 }
1843
1844 if tool_started {
1845 break;
1846 }
1847
1848 cx.background_executor
1849 .timer(Duration::from_millis(10))
1850 .await;
1851 }
1852 assert!(tool_started, "expected cancellation aware tool to start");
1853
1854 // Cancel the thread and wait for it to complete
1855 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1856
1857 // The cancel task should complete promptly because the tool handles cancellation
1858 let timeout = cx.background_executor.timer(Duration::from_secs(5));
1859 futures::select! {
1860 _ = cancel_task.fuse() => {}
1861 _ = timeout.fuse() => {
1862 panic!("cancel task timed out - tool did not respond to cancellation");
1863 }
1864 }
1865
1866 // Verify the tool detected cancellation via its flag
1867 assert!(
1868 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1869 "tool should have detected cancellation via event_stream.cancelled_by_user()"
1870 );
1871
1872 // Collect remaining events
1873 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1874
1875 // Verify we got a cancellation stop event
1876 assert_eq!(
1877 stop_events(remaining_events),
1878 vec![acp::StopReason::Cancelled],
1879 );
1880
1881 // Verify we can send a new message after cancellation
1882 verify_thread_recovery(&thread, &fake_model, cx).await;
1883}
1884
1885/// Helper to verify thread can recover after cancellation by sending a simple message.
1886async fn verify_thread_recovery(
1887 thread: &Entity<Thread>,
1888 fake_model: &FakeLanguageModel,
1889 cx: &mut TestAppContext,
1890) {
1891 let events = thread
1892 .update(cx, |thread, cx| {
1893 thread.send(
1894 UserMessageId::new(),
1895 ["Testing: reply with 'Hello' then stop."],
1896 cx,
1897 )
1898 })
1899 .unwrap();
1900 cx.run_until_parked();
1901 fake_model.send_last_completion_stream_text_chunk("Hello");
1902 fake_model
1903 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1904 fake_model.end_last_completion_stream();
1905
1906 let events = events.collect::<Vec<_>>().await;
1907 thread.update(cx, |thread, _cx| {
1908 let message = thread.last_message().unwrap();
1909 let agent_message = message.as_agent_message().unwrap();
1910 assert_eq!(
1911 agent_message.content,
1912 vec![AgentMessageContent::Text("Hello".to_string())]
1913 );
1914 });
1915 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1916}
1917
1918/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1919async fn wait_for_terminal_tool_started(
1920 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1921 cx: &mut TestAppContext,
1922) {
1923 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1924 for _ in 0..deadline {
1925 cx.run_until_parked();
1926
1927 while let Some(Some(event)) = events.next().now_or_never() {
1928 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1929 update,
1930 ))) = &event
1931 {
1932 if update.fields.content.as_ref().is_some_and(|content| {
1933 content
1934 .iter()
1935 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1936 }) {
1937 return;
1938 }
1939 }
1940 }
1941
1942 cx.background_executor
1943 .timer(Duration::from_millis(10))
1944 .await;
1945 }
1946 panic!("terminal tool did not start within the expected time");
1947}
1948
1949/// Collects events until a Stop event is received, driving the executor to completion.
1950async fn collect_events_until_stop(
1951 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1952 cx: &mut TestAppContext,
1953) -> Vec<Result<ThreadEvent>> {
1954 let mut collected = Vec::new();
1955 let deadline = cx.executor().num_cpus() * 200;
1956
1957 for _ in 0..deadline {
1958 cx.executor().advance_clock(Duration::from_millis(10));
1959 cx.run_until_parked();
1960
1961 while let Some(Some(event)) = events.next().now_or_never() {
1962 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1963 collected.push(event);
1964 if is_stop {
1965 return collected;
1966 }
1967 }
1968 }
1969 panic!(
1970 "did not receive Stop event within the expected time; collected {} events",
1971 collected.len()
1972 );
1973}
1974
1975#[gpui::test]
1976async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1977 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1978 always_allow_tools(cx);
1979 let fake_model = model.as_fake();
1980
1981 let environment = Rc::new(cx.update(|cx| {
1982 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1983 }));
1984 let handle = environment.terminal_handle.clone().unwrap();
1985
1986 let message_id = UserMessageId::new();
1987 let mut events = thread
1988 .update(cx, |thread, cx| {
1989 thread.add_tool(
1990 crate::TerminalTool::new(thread.project().clone(), environment),
1991 None,
1992 );
1993 thread.send(message_id.clone(), ["run a command"], cx)
1994 })
1995 .unwrap();
1996
1997 cx.run_until_parked();
1998
1999 // Simulate the model calling the terminal tool
2000 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2001 LanguageModelToolUse {
2002 id: "terminal_tool_1".into(),
2003 name: TerminalTool::NAME.into(),
2004 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2005 input: json!({"command": "sleep 1000", "cd": "."}),
2006 is_input_complete: true,
2007 thought_signature: None,
2008 },
2009 ));
2010 fake_model.end_last_completion_stream();
2011
2012 // Wait for the terminal tool to start running
2013 wait_for_terminal_tool_started(&mut events, cx).await;
2014
2015 // Truncate the thread while the terminal is running
2016 thread
2017 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2018 .unwrap();
2019
2020 // Drive the executor to let cancellation complete
2021 let _ = collect_events_until_stop(&mut events, cx).await;
2022
2023 // Verify the terminal was killed
2024 assert!(
2025 handle.was_killed(),
2026 "expected terminal handle to be killed on truncate"
2027 );
2028
2029 // Verify the thread is empty after truncation
2030 thread.update(cx, |thread, _cx| {
2031 assert_eq!(
2032 thread.to_markdown(),
2033 "",
2034 "expected thread to be empty after truncating the only message"
2035 );
2036 });
2037
2038 // Verify we can send a new message after truncation
2039 verify_thread_recovery(&thread, &fake_model, cx).await;
2040}
2041
2042#[gpui::test]
2043async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2044 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2045 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2046 always_allow_tools(cx);
2047 let fake_model = model.as_fake();
2048
2049 let environment = Rc::new(MultiTerminalEnvironment::new());
2050
2051 let mut events = thread
2052 .update(cx, |thread, cx| {
2053 thread.add_tool(
2054 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2055 None,
2056 );
2057 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2058 })
2059 .unwrap();
2060
2061 cx.run_until_parked();
2062
2063 // Simulate the model calling two terminal tools
2064 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2065 LanguageModelToolUse {
2066 id: "terminal_tool_1".into(),
2067 name: TerminalTool::NAME.into(),
2068 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2069 input: json!({"command": "sleep 1000", "cd": "."}),
2070 is_input_complete: true,
2071 thought_signature: None,
2072 },
2073 ));
2074 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2075 LanguageModelToolUse {
2076 id: "terminal_tool_2".into(),
2077 name: TerminalTool::NAME.into(),
2078 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2079 input: json!({"command": "sleep 2000", "cd": "."}),
2080 is_input_complete: true,
2081 thought_signature: None,
2082 },
2083 ));
2084 fake_model.end_last_completion_stream();
2085
2086 // Wait for both terminal tools to start by counting terminal content updates
2087 let mut terminals_started = 0;
2088 let deadline = cx.executor().num_cpus() * 100;
2089 for _ in 0..deadline {
2090 cx.run_until_parked();
2091
2092 while let Some(Some(event)) = events.next().now_or_never() {
2093 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2094 update,
2095 ))) = &event
2096 {
2097 if update.fields.content.as_ref().is_some_and(|content| {
2098 content
2099 .iter()
2100 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2101 }) {
2102 terminals_started += 1;
2103 if terminals_started >= 2 {
2104 break;
2105 }
2106 }
2107 }
2108 }
2109 if terminals_started >= 2 {
2110 break;
2111 }
2112
2113 cx.background_executor
2114 .timer(Duration::from_millis(10))
2115 .await;
2116 }
2117 assert!(
2118 terminals_started >= 2,
2119 "expected 2 terminal tools to start, got {terminals_started}"
2120 );
2121
2122 // Cancel the thread while both terminals are running
2123 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2124
2125 // Collect remaining events
2126 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2127
2128 // Verify both terminal handles were killed
2129 let handles = environment.handles();
2130 assert_eq!(
2131 handles.len(),
2132 2,
2133 "expected 2 terminal handles to be created"
2134 );
2135 assert!(
2136 handles[0].was_killed(),
2137 "expected first terminal handle to be killed on cancellation"
2138 );
2139 assert!(
2140 handles[1].was_killed(),
2141 "expected second terminal handle to be killed on cancellation"
2142 );
2143
2144 // Verify we got a cancellation stop event
2145 assert_eq!(
2146 stop_events(remaining_events),
2147 vec![acp::StopReason::Cancelled],
2148 );
2149}
2150
2151#[gpui::test]
2152async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2153 // Tests that clicking the stop button on the terminal card (as opposed to the main
2154 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2155 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2156 always_allow_tools(cx);
2157 let fake_model = model.as_fake();
2158
2159 let environment = Rc::new(cx.update(|cx| {
2160 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2161 }));
2162 let handle = environment.terminal_handle.clone().unwrap();
2163
2164 let mut events = thread
2165 .update(cx, |thread, cx| {
2166 thread.add_tool(
2167 crate::TerminalTool::new(thread.project().clone(), environment),
2168 None,
2169 );
2170 thread.send(UserMessageId::new(), ["run a command"], cx)
2171 })
2172 .unwrap();
2173
2174 cx.run_until_parked();
2175
2176 // Simulate the model calling the terminal tool
2177 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2178 LanguageModelToolUse {
2179 id: "terminal_tool_1".into(),
2180 name: TerminalTool::NAME.into(),
2181 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2182 input: json!({"command": "sleep 1000", "cd": "."}),
2183 is_input_complete: true,
2184 thought_signature: None,
2185 },
2186 ));
2187 fake_model.end_last_completion_stream();
2188
2189 // Wait for the terminal tool to start running
2190 wait_for_terminal_tool_started(&mut events, cx).await;
2191
2192 // Simulate user clicking stop on the terminal card itself.
2193 // This sets the flag and signals exit (simulating what the real UI would do).
2194 handle.set_stopped_by_user(true);
2195 handle.killed.store(true, Ordering::SeqCst);
2196 handle.signal_exit();
2197
2198 // Wait for the tool to complete
2199 cx.run_until_parked();
2200
2201 // The thread continues after tool completion - simulate the model ending its turn
2202 fake_model
2203 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2204 fake_model.end_last_completion_stream();
2205
2206 // Collect remaining events
2207 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2208
2209 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2210 assert_eq!(
2211 stop_events(remaining_events),
2212 vec![acp::StopReason::EndTurn],
2213 );
2214
2215 // Verify the tool result indicates user stopped
2216 thread.update(cx, |thread, _cx| {
2217 let message = thread.last_message().unwrap();
2218 let agent_message = message.as_agent_message().unwrap();
2219
2220 let tool_use = agent_message
2221 .content
2222 .iter()
2223 .find_map(|content| match content {
2224 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2225 _ => None,
2226 })
2227 .expect("expected tool use in agent message");
2228
2229 let tool_result = agent_message
2230 .tool_results
2231 .get(&tool_use.id)
2232 .expect("expected tool result");
2233
2234 let result_text = match &tool_result.content {
2235 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2236 _ => panic!("expected text content in tool result"),
2237 };
2238
2239 assert!(
2240 result_text.contains("The user stopped this command"),
2241 "expected tool result to indicate user stopped, got: {result_text}"
2242 );
2243 });
2244}
2245
2246#[gpui::test]
2247async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2248 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2249 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2250 always_allow_tools(cx);
2251 let fake_model = model.as_fake();
2252
2253 let environment = Rc::new(cx.update(|cx| {
2254 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2255 }));
2256 let handle = environment.terminal_handle.clone().unwrap();
2257
2258 let mut events = thread
2259 .update(cx, |thread, cx| {
2260 thread.add_tool(
2261 crate::TerminalTool::new(thread.project().clone(), environment),
2262 None,
2263 );
2264 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2265 })
2266 .unwrap();
2267
2268 cx.run_until_parked();
2269
2270 // Simulate the model calling the terminal tool with a short timeout
2271 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2272 LanguageModelToolUse {
2273 id: "terminal_tool_1".into(),
2274 name: TerminalTool::NAME.into(),
2275 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2276 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2277 is_input_complete: true,
2278 thought_signature: None,
2279 },
2280 ));
2281 fake_model.end_last_completion_stream();
2282
2283 // Wait for the terminal tool to start running
2284 wait_for_terminal_tool_started(&mut events, cx).await;
2285
2286 // Advance clock past the timeout
2287 cx.executor().advance_clock(Duration::from_millis(200));
2288 cx.run_until_parked();
2289
2290 // The thread continues after tool completion - simulate the model ending its turn
2291 fake_model
2292 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2293 fake_model.end_last_completion_stream();
2294
2295 // Collect remaining events
2296 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2297
2298 // Verify the terminal was killed due to timeout
2299 assert!(
2300 handle.was_killed(),
2301 "expected terminal handle to be killed on timeout"
2302 );
2303
2304 // Verify we got an EndTurn (the tool completed, just with timeout)
2305 assert_eq!(
2306 stop_events(remaining_events),
2307 vec![acp::StopReason::EndTurn],
2308 );
2309
2310 // Verify the tool result indicates timeout, not user stopped
2311 thread.update(cx, |thread, _cx| {
2312 let message = thread.last_message().unwrap();
2313 let agent_message = message.as_agent_message().unwrap();
2314
2315 let tool_use = agent_message
2316 .content
2317 .iter()
2318 .find_map(|content| match content {
2319 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2320 _ => None,
2321 })
2322 .expect("expected tool use in agent message");
2323
2324 let tool_result = agent_message
2325 .tool_results
2326 .get(&tool_use.id)
2327 .expect("expected tool result");
2328
2329 let result_text = match &tool_result.content {
2330 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2331 _ => panic!("expected text content in tool result"),
2332 };
2333
2334 assert!(
2335 result_text.contains("timed out"),
2336 "expected tool result to indicate timeout, got: {result_text}"
2337 );
2338 assert!(
2339 !result_text.contains("The user stopped"),
2340 "tool result should not mention user stopped when it timed out, got: {result_text}"
2341 );
2342 });
2343}
2344
2345#[gpui::test]
2346async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2347 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2348 let fake_model = model.as_fake();
2349
2350 let events_1 = thread
2351 .update(cx, |thread, cx| {
2352 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2353 })
2354 .unwrap();
2355 cx.run_until_parked();
2356 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2357 cx.run_until_parked();
2358
2359 let events_2 = thread
2360 .update(cx, |thread, cx| {
2361 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2362 })
2363 .unwrap();
2364 cx.run_until_parked();
2365 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2366 fake_model
2367 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2368 fake_model.end_last_completion_stream();
2369
2370 let events_1 = events_1.collect::<Vec<_>>().await;
2371 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2372 let events_2 = events_2.collect::<Vec<_>>().await;
2373 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2374}
2375
2376#[gpui::test]
2377async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2378 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2379 let fake_model = model.as_fake();
2380
2381 let events_1 = thread
2382 .update(cx, |thread, cx| {
2383 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2384 })
2385 .unwrap();
2386 cx.run_until_parked();
2387 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2388 fake_model
2389 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2390 fake_model.end_last_completion_stream();
2391 let events_1 = events_1.collect::<Vec<_>>().await;
2392
2393 let events_2 = thread
2394 .update(cx, |thread, cx| {
2395 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2396 })
2397 .unwrap();
2398 cx.run_until_parked();
2399 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2400 fake_model
2401 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2402 fake_model.end_last_completion_stream();
2403 let events_2 = events_2.collect::<Vec<_>>().await;
2404
2405 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2406 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2407}
2408
2409#[gpui::test]
2410async fn test_refusal(cx: &mut TestAppContext) {
2411 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2412 let fake_model = model.as_fake();
2413
2414 let events = thread
2415 .update(cx, |thread, cx| {
2416 thread.send(UserMessageId::new(), ["Hello"], cx)
2417 })
2418 .unwrap();
2419 cx.run_until_parked();
2420 thread.read_with(cx, |thread, _| {
2421 assert_eq!(
2422 thread.to_markdown(),
2423 indoc! {"
2424 ## User
2425
2426 Hello
2427 "}
2428 );
2429 });
2430
2431 fake_model.send_last_completion_stream_text_chunk("Hey!");
2432 cx.run_until_parked();
2433 thread.read_with(cx, |thread, _| {
2434 assert_eq!(
2435 thread.to_markdown(),
2436 indoc! {"
2437 ## User
2438
2439 Hello
2440
2441 ## Assistant
2442
2443 Hey!
2444 "}
2445 );
2446 });
2447
2448 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2449 fake_model
2450 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2451 let events = events.collect::<Vec<_>>().await;
2452 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2453 thread.read_with(cx, |thread, _| {
2454 assert_eq!(thread.to_markdown(), "");
2455 });
2456}
2457
2458#[gpui::test]
2459async fn test_truncate_first_message(cx: &mut TestAppContext) {
2460 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2461 let fake_model = model.as_fake();
2462
2463 let message_id = UserMessageId::new();
2464 thread
2465 .update(cx, |thread, cx| {
2466 thread.send(message_id.clone(), ["Hello"], cx)
2467 })
2468 .unwrap();
2469 cx.run_until_parked();
2470 thread.read_with(cx, |thread, _| {
2471 assert_eq!(
2472 thread.to_markdown(),
2473 indoc! {"
2474 ## User
2475
2476 Hello
2477 "}
2478 );
2479 assert_eq!(thread.latest_token_usage(), None);
2480 });
2481
2482 fake_model.send_last_completion_stream_text_chunk("Hey!");
2483 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2484 language_model::TokenUsage {
2485 input_tokens: 32_000,
2486 output_tokens: 16_000,
2487 cache_creation_input_tokens: 0,
2488 cache_read_input_tokens: 0,
2489 },
2490 ));
2491 cx.run_until_parked();
2492 thread.read_with(cx, |thread, _| {
2493 assert_eq!(
2494 thread.to_markdown(),
2495 indoc! {"
2496 ## User
2497
2498 Hello
2499
2500 ## Assistant
2501
2502 Hey!
2503 "}
2504 );
2505 assert_eq!(
2506 thread.latest_token_usage(),
2507 Some(acp_thread::TokenUsage {
2508 used_tokens: 32_000 + 16_000,
2509 max_tokens: 1_000_000,
2510 input_tokens: 32_000,
2511 output_tokens: 16_000,
2512 })
2513 );
2514 });
2515
2516 thread
2517 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2518 .unwrap();
2519 cx.run_until_parked();
2520 thread.read_with(cx, |thread, _| {
2521 assert_eq!(thread.to_markdown(), "");
2522 assert_eq!(thread.latest_token_usage(), None);
2523 });
2524
2525 // Ensure we can still send a new message after truncation.
2526 thread
2527 .update(cx, |thread, cx| {
2528 thread.send(UserMessageId::new(), ["Hi"], cx)
2529 })
2530 .unwrap();
2531 thread.update(cx, |thread, _cx| {
2532 assert_eq!(
2533 thread.to_markdown(),
2534 indoc! {"
2535 ## User
2536
2537 Hi
2538 "}
2539 );
2540 });
2541 cx.run_until_parked();
2542 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2543 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2544 language_model::TokenUsage {
2545 input_tokens: 40_000,
2546 output_tokens: 20_000,
2547 cache_creation_input_tokens: 0,
2548 cache_read_input_tokens: 0,
2549 },
2550 ));
2551 cx.run_until_parked();
2552 thread.read_with(cx, |thread, _| {
2553 assert_eq!(
2554 thread.to_markdown(),
2555 indoc! {"
2556 ## User
2557
2558 Hi
2559
2560 ## Assistant
2561
2562 Ahoy!
2563 "}
2564 );
2565
2566 assert_eq!(
2567 thread.latest_token_usage(),
2568 Some(acp_thread::TokenUsage {
2569 used_tokens: 40_000 + 20_000,
2570 max_tokens: 1_000_000,
2571 input_tokens: 40_000,
2572 output_tokens: 20_000,
2573 })
2574 );
2575 });
2576}
2577
2578#[gpui::test]
2579async fn test_truncate_second_message(cx: &mut TestAppContext) {
2580 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2581 let fake_model = model.as_fake();
2582
2583 thread
2584 .update(cx, |thread, cx| {
2585 thread.send(UserMessageId::new(), ["Message 1"], cx)
2586 })
2587 .unwrap();
2588 cx.run_until_parked();
2589 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2590 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2591 language_model::TokenUsage {
2592 input_tokens: 32_000,
2593 output_tokens: 16_000,
2594 cache_creation_input_tokens: 0,
2595 cache_read_input_tokens: 0,
2596 },
2597 ));
2598 fake_model.end_last_completion_stream();
2599 cx.run_until_parked();
2600
2601 let assert_first_message_state = |cx: &mut TestAppContext| {
2602 thread.clone().read_with(cx, |thread, _| {
2603 assert_eq!(
2604 thread.to_markdown(),
2605 indoc! {"
2606 ## User
2607
2608 Message 1
2609
2610 ## Assistant
2611
2612 Message 1 response
2613 "}
2614 );
2615
2616 assert_eq!(
2617 thread.latest_token_usage(),
2618 Some(acp_thread::TokenUsage {
2619 used_tokens: 32_000 + 16_000,
2620 max_tokens: 1_000_000,
2621 input_tokens: 32_000,
2622 output_tokens: 16_000,
2623 })
2624 );
2625 });
2626 };
2627
2628 assert_first_message_state(cx);
2629
2630 let second_message_id = UserMessageId::new();
2631 thread
2632 .update(cx, |thread, cx| {
2633 thread.send(second_message_id.clone(), ["Message 2"], cx)
2634 })
2635 .unwrap();
2636 cx.run_until_parked();
2637
2638 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2639 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2640 language_model::TokenUsage {
2641 input_tokens: 40_000,
2642 output_tokens: 20_000,
2643 cache_creation_input_tokens: 0,
2644 cache_read_input_tokens: 0,
2645 },
2646 ));
2647 fake_model.end_last_completion_stream();
2648 cx.run_until_parked();
2649
2650 thread.read_with(cx, |thread, _| {
2651 assert_eq!(
2652 thread.to_markdown(),
2653 indoc! {"
2654 ## User
2655
2656 Message 1
2657
2658 ## Assistant
2659
2660 Message 1 response
2661
2662 ## User
2663
2664 Message 2
2665
2666 ## Assistant
2667
2668 Message 2 response
2669 "}
2670 );
2671
2672 assert_eq!(
2673 thread.latest_token_usage(),
2674 Some(acp_thread::TokenUsage {
2675 used_tokens: 40_000 + 20_000,
2676 max_tokens: 1_000_000,
2677 input_tokens: 40_000,
2678 output_tokens: 20_000,
2679 })
2680 );
2681 });
2682
2683 thread
2684 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2685 .unwrap();
2686 cx.run_until_parked();
2687
2688 assert_first_message_state(cx);
2689}
2690
2691#[gpui::test]
2692async fn test_title_generation(cx: &mut TestAppContext) {
2693 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2694 let fake_model = model.as_fake();
2695
2696 let summary_model = Arc::new(FakeLanguageModel::default());
2697 thread.update(cx, |thread, cx| {
2698 thread.set_summarization_model(Some(summary_model.clone()), cx)
2699 });
2700
2701 let send = thread
2702 .update(cx, |thread, cx| {
2703 thread.send(UserMessageId::new(), ["Hello"], cx)
2704 })
2705 .unwrap();
2706 cx.run_until_parked();
2707
2708 fake_model.send_last_completion_stream_text_chunk("Hey!");
2709 fake_model.end_last_completion_stream();
2710 cx.run_until_parked();
2711 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2712
2713 // Ensure the summary model has been invoked to generate a title.
2714 summary_model.send_last_completion_stream_text_chunk("Hello ");
2715 summary_model.send_last_completion_stream_text_chunk("world\nG");
2716 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2717 summary_model.end_last_completion_stream();
2718 send.collect::<Vec<_>>().await;
2719 cx.run_until_parked();
2720 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2721
2722 // Send another message, ensuring no title is generated this time.
2723 let send = thread
2724 .update(cx, |thread, cx| {
2725 thread.send(UserMessageId::new(), ["Hello again"], cx)
2726 })
2727 .unwrap();
2728 cx.run_until_parked();
2729 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2730 fake_model.end_last_completion_stream();
2731 cx.run_until_parked();
2732 assert_eq!(summary_model.pending_completions(), Vec::new());
2733 send.collect::<Vec<_>>().await;
2734 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2735}
2736
2737#[gpui::test]
2738async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2739 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2740 let fake_model = model.as_fake();
2741
2742 let _events = thread
2743 .update(cx, |thread, cx| {
2744 thread.add_tool(ToolRequiringPermission, None);
2745 thread.add_tool(EchoTool, None);
2746 thread.send(UserMessageId::new(), ["Hey!"], cx)
2747 })
2748 .unwrap();
2749 cx.run_until_parked();
2750
2751 let permission_tool_use = LanguageModelToolUse {
2752 id: "tool_id_1".into(),
2753 name: ToolRequiringPermission::NAME.into(),
2754 raw_input: "{}".into(),
2755 input: json!({}),
2756 is_input_complete: true,
2757 thought_signature: None,
2758 };
2759 let echo_tool_use = LanguageModelToolUse {
2760 id: "tool_id_2".into(),
2761 name: EchoTool::NAME.into(),
2762 raw_input: json!({"text": "test"}).to_string(),
2763 input: json!({"text": "test"}),
2764 is_input_complete: true,
2765 thought_signature: None,
2766 };
2767 fake_model.send_last_completion_stream_text_chunk("Hi!");
2768 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2769 permission_tool_use,
2770 ));
2771 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2772 echo_tool_use.clone(),
2773 ));
2774 fake_model.end_last_completion_stream();
2775 cx.run_until_parked();
2776
2777 // Ensure pending tools are skipped when building a request.
2778 let request = thread
2779 .read_with(cx, |thread, cx| {
2780 thread.build_completion_request(CompletionIntent::EditFile, cx)
2781 })
2782 .unwrap();
2783 assert_eq!(
2784 request.messages[1..],
2785 vec![
2786 LanguageModelRequestMessage {
2787 role: Role::User,
2788 content: vec!["Hey!".into()],
2789 cache: true,
2790 reasoning_details: None,
2791 },
2792 LanguageModelRequestMessage {
2793 role: Role::Assistant,
2794 content: vec![
2795 MessageContent::Text("Hi!".into()),
2796 MessageContent::ToolUse(echo_tool_use.clone())
2797 ],
2798 cache: false,
2799 reasoning_details: None,
2800 },
2801 LanguageModelRequestMessage {
2802 role: Role::User,
2803 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2804 tool_use_id: echo_tool_use.id.clone(),
2805 tool_name: echo_tool_use.name,
2806 is_error: false,
2807 content: "test".into(),
2808 output: Some("test".into())
2809 })],
2810 cache: false,
2811 reasoning_details: None,
2812 },
2813 ],
2814 );
2815}
2816
2817#[gpui::test]
2818async fn test_agent_connection(cx: &mut TestAppContext) {
2819 cx.update(settings::init);
2820 let templates = Templates::new();
2821
2822 // Initialize language model system with test provider
2823 cx.update(|cx| {
2824 gpui_tokio::init(cx);
2825
2826 let http_client = FakeHttpClient::with_404_response();
2827 let clock = Arc::new(clock::FakeSystemClock::new());
2828 let client = Client::new(clock, http_client, cx);
2829 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2830 language_model::init(client.clone(), cx);
2831 language_models::init(user_store, client.clone(), cx);
2832 LanguageModelRegistry::test(cx);
2833 });
2834 cx.executor().forbid_parking();
2835
2836 // Create a project for new_thread
2837 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2838 fake_fs.insert_tree(path!("/test"), json!({})).await;
2839 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2840 let cwd = Path::new("/test");
2841 let thread_store = cx.new(|cx| ThreadStore::new(cx));
2842
2843 // Create agent and connection
2844 let agent = NativeAgent::new(
2845 project.clone(),
2846 thread_store,
2847 templates.clone(),
2848 None,
2849 fake_fs.clone(),
2850 &mut cx.to_async(),
2851 )
2852 .await
2853 .unwrap();
2854 let connection = NativeAgentConnection(agent.clone());
2855
2856 // Create a thread using new_thread
2857 let connection_rc = Rc::new(connection.clone());
2858 let acp_thread = cx
2859 .update(|cx| connection_rc.new_session(project, cwd, cx))
2860 .await
2861 .expect("new_thread should succeed");
2862
2863 // Get the session_id from the AcpThread
2864 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2865
2866 // Test model_selector returns Some
2867 let selector_opt = connection.model_selector(&session_id);
2868 assert!(
2869 selector_opt.is_some(),
2870 "agent should always support ModelSelector"
2871 );
2872 let selector = selector_opt.unwrap();
2873
2874 // Test list_models
2875 let listed_models = cx
2876 .update(|cx| selector.list_models(cx))
2877 .await
2878 .expect("list_models should succeed");
2879 let AgentModelList::Grouped(listed_models) = listed_models else {
2880 panic!("Unexpected model list type");
2881 };
2882 assert!(!listed_models.is_empty(), "should have at least one model");
2883 assert_eq!(
2884 listed_models[&AgentModelGroupName("Fake".into())][0]
2885 .id
2886 .0
2887 .as_ref(),
2888 "fake/fake"
2889 );
2890
2891 // Test selected_model returns the default
2892 let model = cx
2893 .update(|cx| selector.selected_model(cx))
2894 .await
2895 .expect("selected_model should succeed");
2896 let model = cx
2897 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2898 .unwrap();
2899 let model = model.as_fake();
2900 assert_eq!(model.id().0, "fake", "should return default model");
2901
2902 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2903 cx.run_until_parked();
2904 model.send_last_completion_stream_text_chunk("def");
2905 cx.run_until_parked();
2906 acp_thread.read_with(cx, |thread, cx| {
2907 assert_eq!(
2908 thread.to_markdown(cx),
2909 indoc! {"
2910 ## User
2911
2912 abc
2913
2914 ## Assistant
2915
2916 def
2917
2918 "}
2919 )
2920 });
2921
2922 // Test cancel
2923 cx.update(|cx| connection.cancel(&session_id, cx));
2924 request.await.expect("prompt should fail gracefully");
2925
2926 // Explicitly close the session and drop the ACP thread.
2927 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
2928 .await
2929 .unwrap();
2930 drop(acp_thread);
2931 let result = cx
2932 .update(|cx| {
2933 connection.prompt(
2934 Some(acp_thread::UserMessageId::new()),
2935 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2936 cx,
2937 )
2938 })
2939 .await;
2940 assert_eq!(
2941 result.as_ref().unwrap_err().to_string(),
2942 "Session not found",
2943 "unexpected result: {:?}",
2944 result
2945 );
2946}
2947
2948#[gpui::test]
2949async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2950 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2951 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool, None));
2952 let fake_model = model.as_fake();
2953
2954 let mut events = thread
2955 .update(cx, |thread, cx| {
2956 thread.send(UserMessageId::new(), ["Think"], cx)
2957 })
2958 .unwrap();
2959 cx.run_until_parked();
2960
2961 // Simulate streaming partial input.
2962 let input = json!({});
2963 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2964 LanguageModelToolUse {
2965 id: "1".into(),
2966 name: ThinkingTool::NAME.into(),
2967 raw_input: input.to_string(),
2968 input,
2969 is_input_complete: false,
2970 thought_signature: None,
2971 },
2972 ));
2973
2974 // Input streaming completed
2975 let input = json!({ "content": "Thinking hard!" });
2976 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2977 LanguageModelToolUse {
2978 id: "1".into(),
2979 name: "thinking".into(),
2980 raw_input: input.to_string(),
2981 input,
2982 is_input_complete: true,
2983 thought_signature: None,
2984 },
2985 ));
2986 fake_model.end_last_completion_stream();
2987 cx.run_until_parked();
2988
2989 let tool_call = expect_tool_call(&mut events).await;
2990 assert_eq!(
2991 tool_call,
2992 acp::ToolCall::new("1", "Thinking")
2993 .kind(acp::ToolKind::Think)
2994 .raw_input(json!({}))
2995 .meta(acp::Meta::from_iter([(
2996 "tool_name".into(),
2997 "thinking".into()
2998 )]))
2999 );
3000 let update = expect_tool_call_update_fields(&mut events).await;
3001 assert_eq!(
3002 update,
3003 acp::ToolCallUpdate::new(
3004 "1",
3005 acp::ToolCallUpdateFields::new()
3006 .title("Thinking")
3007 .kind(acp::ToolKind::Think)
3008 .raw_input(json!({ "content": "Thinking hard!"}))
3009 )
3010 );
3011 let update = expect_tool_call_update_fields(&mut events).await;
3012 assert_eq!(
3013 update,
3014 acp::ToolCallUpdate::new(
3015 "1",
3016 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3017 )
3018 );
3019 let update = expect_tool_call_update_fields(&mut events).await;
3020 assert_eq!(
3021 update,
3022 acp::ToolCallUpdate::new(
3023 "1",
3024 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
3025 )
3026 );
3027 let update = expect_tool_call_update_fields(&mut events).await;
3028 assert_eq!(
3029 update,
3030 acp::ToolCallUpdate::new(
3031 "1",
3032 acp::ToolCallUpdateFields::new()
3033 .status(acp::ToolCallStatus::Completed)
3034 .raw_output("Finished thinking.")
3035 )
3036 );
3037}
3038
3039#[gpui::test]
3040async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3041 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3042 let fake_model = model.as_fake();
3043
3044 let mut events = thread
3045 .update(cx, |thread, cx| {
3046 thread.send(UserMessageId::new(), ["Hello!"], cx)
3047 })
3048 .unwrap();
3049 cx.run_until_parked();
3050
3051 fake_model.send_last_completion_stream_text_chunk("Hey!");
3052 fake_model.end_last_completion_stream();
3053
3054 let mut retry_events = Vec::new();
3055 while let Some(Ok(event)) = events.next().await {
3056 match event {
3057 ThreadEvent::Retry(retry_status) => {
3058 retry_events.push(retry_status);
3059 }
3060 ThreadEvent::Stop(..) => break,
3061 _ => {}
3062 }
3063 }
3064
3065 assert_eq!(retry_events.len(), 0);
3066 thread.read_with(cx, |thread, _cx| {
3067 assert_eq!(
3068 thread.to_markdown(),
3069 indoc! {"
3070 ## User
3071
3072 Hello!
3073
3074 ## Assistant
3075
3076 Hey!
3077 "}
3078 )
3079 });
3080}
3081
3082#[gpui::test]
3083async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3084 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3085 let fake_model = model.as_fake();
3086
3087 let mut events = thread
3088 .update(cx, |thread, cx| {
3089 thread.send(UserMessageId::new(), ["Hello!"], cx)
3090 })
3091 .unwrap();
3092 cx.run_until_parked();
3093
3094 fake_model.send_last_completion_stream_text_chunk("Hey,");
3095 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3096 provider: LanguageModelProviderName::new("Anthropic"),
3097 retry_after: Some(Duration::from_secs(3)),
3098 });
3099 fake_model.end_last_completion_stream();
3100
3101 cx.executor().advance_clock(Duration::from_secs(3));
3102 cx.run_until_parked();
3103
3104 fake_model.send_last_completion_stream_text_chunk("there!");
3105 fake_model.end_last_completion_stream();
3106 cx.run_until_parked();
3107
3108 let mut retry_events = Vec::new();
3109 while let Some(Ok(event)) = events.next().await {
3110 match event {
3111 ThreadEvent::Retry(retry_status) => {
3112 retry_events.push(retry_status);
3113 }
3114 ThreadEvent::Stop(..) => break,
3115 _ => {}
3116 }
3117 }
3118
3119 assert_eq!(retry_events.len(), 1);
3120 assert!(matches!(
3121 retry_events[0],
3122 acp_thread::RetryStatus { attempt: 1, .. }
3123 ));
3124 thread.read_with(cx, |thread, _cx| {
3125 assert_eq!(
3126 thread.to_markdown(),
3127 indoc! {"
3128 ## User
3129
3130 Hello!
3131
3132 ## Assistant
3133
3134 Hey,
3135
3136 [resume]
3137
3138 ## Assistant
3139
3140 there!
3141 "}
3142 )
3143 });
3144}
3145
// Verifies that a tool call streamed before a retryable completion error is
// still finished, and that its result is part of the request issued on retry.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a complete tool use, then fail the same stream with an error that
    // carries a 3 second `retry_after`.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the `retry_after` duration, then inspect the
    // request that is now pending on the fake model.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    // The first message is skipped (presumably the system prompt; confirm against `Thread`).
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Let the retried request succeed and check the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3225
3226#[gpui::test]
3227async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3228 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3229 let fake_model = model.as_fake();
3230
3231 let mut events = thread
3232 .update(cx, |thread, cx| {
3233 thread.send(UserMessageId::new(), ["Hello!"], cx)
3234 })
3235 .unwrap();
3236 cx.run_until_parked();
3237
3238 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3239 fake_model.send_last_completion_stream_error(
3240 LanguageModelCompletionError::ServerOverloaded {
3241 provider: LanguageModelProviderName::new("Anthropic"),
3242 retry_after: Some(Duration::from_secs(3)),
3243 },
3244 );
3245 fake_model.end_last_completion_stream();
3246 cx.executor().advance_clock(Duration::from_secs(3));
3247 cx.run_until_parked();
3248 }
3249
3250 let mut errors = Vec::new();
3251 let mut retry_events = Vec::new();
3252 while let Some(event) = events.next().await {
3253 match event {
3254 Ok(ThreadEvent::Retry(retry_status)) => {
3255 retry_events.push(retry_status);
3256 }
3257 Ok(ThreadEvent::Stop(..)) => break,
3258 Err(error) => errors.push(error),
3259 _ => {}
3260 }
3261 }
3262
3263 assert_eq!(
3264 retry_events.len(),
3265 crate::thread::MAX_RETRY_ATTEMPTS as usize
3266 );
3267 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3268 assert_eq!(retry_events[i].attempt, i + 1);
3269 }
3270 assert_eq!(errors.len(), 1);
3271 let error = errors[0]
3272 .downcast_ref::<LanguageModelCompletionError>()
3273 .unwrap();
3274 assert!(matches!(
3275 error,
3276 LanguageModelCompletionError::ServerOverloaded { .. }
3277 ));
3278}
3279
3280/// Filters out the stop events for asserting against in tests
3281fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3282 result_events
3283 .into_iter()
3284 .filter_map(|event| match event.unwrap() {
3285 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3286 _ => None,
3287 })
3288 .collect()
3289}
3290
/// Fixture produced by `setup`: the thread under test together with the
/// collaborators it was constructed from.
struct ThreadTest {
    /// Language model that was handed to the thread.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// Context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// Fake file system that backs the project and holds the settings file.
    fs: Arc<FakeFs>,
}
3298
/// Selects which language model `setup` gives to the thread.
enum TestModel {
    /// A model looked up in the global language model registry by id.
    Sonnet4,
    /// A `FakeLanguageModel` constructed directly.
    Fake,
}
3303
3304impl TestModel {
3305 fn id(&self) -> LanguageModelId {
3306 match self {
3307 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3308 TestModel::Fake => unreachable!(),
3309 }
3310 }
3311}
3312
/// Builds a `Thread` backed by a fake file system and the requested model.
///
/// Writes a settings file that defines a `test-profile` agent profile with the
/// test tools enabled, initializes settings (and, for `TestModel::Sonnet4`, the
/// production client and language model providers), creates a project rooted
/// at `/test`, and returns the thread together with its collaborators.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): presumably required so the non-fake model can wait on real I/O; confirm.
    cx.executor().allow_parking();

    // Seed the fake file system with a user settings file.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            ThinkingTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // The non-fake model needs an HTTP client and the language
                // model providers to be registered.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Feed the settings file written above into the global settings store.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: either a fresh fake, or the registry entry with the
    // matching id once its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3416
// Initializes `env_logger`, but only when `RUST_LOG` is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3424
3425fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3426 let fs = fs.clone();
3427 cx.spawn({
3428 async move |cx| {
3429 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3430 cx.background_executor(),
3431 fs,
3432 paths::settings_file().clone(),
3433 );
3434 let _watcher_task = watcher_task;
3435
3436 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3437 cx.update(|cx| {
3438 SettingsStore::update_global(cx, |settings, cx| {
3439 settings.set_user_settings(&new_settings_content, cx)
3440 })
3441 })
3442 .ok();
3443 }
3444 }
3445 })
3446 .detach();
3447}
3448
3449fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3450 completion
3451 .tools
3452 .iter()
3453 .map(|tool| tool.name.clone())
3454 .collect()
3455}
3456
/// Registers and starts a fake context server called `name` that advertises `tools`.
///
/// Returns a receiver that yields one item per `CallTool` request: the request
/// parameters plus a oneshot sender the test uses to supply the response. The
/// request handler waits until that response has been sent.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add an enabled stdio server entry for `name` to the project settings.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // The fake transport answers the three requests below.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and wait for its reply.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3536
// Checks `Thread::tokens_before_message` across three user messages: the first
// message reports `None`, and each later message reports the input token count
// of the request that preceded it.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3643
// Checks that `tokens_before_message` returns `None` for a message that has
// been removed by `Thread::truncate`, while earlier messages are unaffected.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two messages, each answered with a response and a usage update
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3714
// Runs the terminal tool under four permission configurations. Each scenario
// installs its own `ToolRules` for the terminal tool in the global
// `AgentSettings`, builds a fresh tool and event stream, and checks whether the
// command is blocked or allowed to run.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Confirm,
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        // The comparison is case-insensitive because the message is lowercased first.
        let err_msg = result.unwrap_err().to_string().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = false;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // The next update on the stream must already carry terminal content.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: always_allow_tool_actions=true overrides always_confirm patterns
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Allow,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, confirm patterns are overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }

    // Test 4: always_allow_tool_actions=true overrides default_mode: Deny
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default_mode: settings::ToolPermissionMode::Deny,
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                },
                event_stream,
                cx,
            )
        });

        // With always_allow_tool_actions=true, even default_mode: Deny is overridden
        task.await
            .expect("command should be allowed with always_allow_tool_actions=true");
    }
}
3922
3923#[gpui::test]
3924async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3925 init_test(cx);
3926
3927 cx.update(|cx| {
3928 cx.update_flags(true, vec!["subagents".to_string()]);
3929 });
3930
3931 let fs = FakeFs::new(cx.executor());
3932 fs.insert_tree(path!("/test"), json!({})).await;
3933 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3934 let project_context = cx.new(|_cx| ProjectContext::default());
3935 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3936 let context_server_registry =
3937 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3938 let model = Arc::new(FakeLanguageModel::default());
3939
3940 let environment = Rc::new(cx.update(|cx| {
3941 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3942 }));
3943
3944 let thread = cx.new(|cx| {
3945 let mut thread = Thread::new(
3946 project.clone(),
3947 project_context,
3948 context_server_registry,
3949 Templates::new(),
3950 Some(model),
3951 cx,
3952 );
3953 thread.add_default_tools(None, environment, cx);
3954 thread
3955 });
3956
3957 thread.read_with(cx, |thread, _| {
3958 assert!(
3959 thread.has_registered_tool(SubagentTool::NAME),
3960 "subagent tool should be present when feature flag is enabled"
3961 );
3962 });
3963}
3964
3965#[gpui::test]
3966async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
3967 init_test(cx);
3968
3969 cx.update(|cx| {
3970 cx.update_flags(true, vec!["subagents".to_string()]);
3971 });
3972
3973 let fs = FakeFs::new(cx.executor());
3974 fs.insert_tree(path!("/test"), json!({})).await;
3975 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3976 let project_context = cx.new(|_cx| ProjectContext::default());
3977 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3978 let context_server_registry =
3979 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3980 let model = Arc::new(FakeLanguageModel::default());
3981
3982 let parent_thread = cx.new(|cx| {
3983 Thread::new(
3984 project.clone(),
3985 project_context,
3986 context_server_registry,
3987 Templates::new(),
3988 Some(model.clone()),
3989 cx,
3990 )
3991 });
3992
3993 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
3994 subagent_thread.read_with(cx, |subagent_thread, cx| {
3995 assert!(subagent_thread.is_subagent());
3996 assert_eq!(subagent_thread.depth(), 1);
3997 assert_eq!(
3998 subagent_thread.model().map(|model| model.id()),
3999 Some(model.id())
4000 );
4001 assert_eq!(
4002 subagent_thread.parent_thread_id(),
4003 Some(parent_thread.read(cx).id().clone())
4004 );
4005 });
4006}
4007
4008#[gpui::test]
4009async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4010 init_test(cx);
4011
4012 cx.update(|cx| {
4013 cx.update_flags(true, vec!["subagents".to_string()]);
4014 });
4015
4016 let fs = FakeFs::new(cx.executor());
4017 fs.insert_tree(path!("/test"), json!({})).await;
4018 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4019 let project_context = cx.new(|_cx| ProjectContext::default());
4020 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4021 let context_server_registry =
4022 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4023 let model = Arc::new(FakeLanguageModel::default());
4024 let environment = Rc::new(cx.update(|cx| {
4025 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4026 }));
4027
4028 let deep_parent_thread = cx.new(|cx| {
4029 let mut thread = Thread::new(
4030 project.clone(),
4031 project_context,
4032 context_server_registry,
4033 Templates::new(),
4034 Some(model.clone()),
4035 cx,
4036 );
4037 thread.set_subagent_context(SubagentContext {
4038 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4039 depth: MAX_SUBAGENT_DEPTH - 1,
4040 });
4041 thread
4042 });
4043 let deep_subagent_thread = cx.new(|cx| {
4044 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4045 thread.add_default_tools(None, environment, cx);
4046 thread
4047 });
4048
4049 deep_subagent_thread.read_with(cx, |thread, _| {
4050 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4051 assert!(
4052 !thread.has_registered_tool(SubagentTool::NAME),
4053 "subagent tool should not be present at max depth"
4054 );
4055 });
4056}
4057
// Checks that cancelling a parent thread also ends the turn of a subagent
// that was registered on it as running.
#[gpui::test]
async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
    init_test(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));

    // The parent only holds a weak handle to the subagent.
    parent.update(cx, |thread, _cx| {
        thread.register_running_subagent(subagent.downgrade());
    });

    // Start a turn on the subagent; the fake model never answers, so it stays in progress.
    subagent
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
        })
        .unwrap();
    cx.run_until_parked();

    subagent.read_with(cx, |thread, _| {
        assert!(!thread.is_turn_complete(), "subagent should be running");
    });

    parent.update(cx, |thread, cx| {
        thread.cancel(cx).detach();
    });

    // Checked right after `cancel`, without running the executor in between.
    subagent.read_with(cx, |thread, _| {
        assert!(
            thread.is_turn_complete(),
            "subagent should be cancelled when parent cancels"
        );
    });
}
4114
// Runs the subagent tool against a fake subagent, signals user cancellation
// through the event stream, and expects the tool task to fail with a
// "cancelled by user" error before a 5 second timer fires.
#[gpui::test]
async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
    // This test verifies that the subagent tool properly handles user cancellation
    // via `event_stream.cancelled_by_user()` and stops all running subagents.
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_subagent(FakeSubagentHandle::new_never_completes(cx))
    }));

    let parent = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context.clone(),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(SubagentTool::new(parent.downgrade(), environment));

    let (event_stream, _rx, mut cancellation_tx) =
        crate::ToolCallEventStream::test_with_cancellation();

    // Start the subagent tool
    let task = cx.update(|cx| {
        tool.run(
            SubagentToolInput {
                label: "Long running task".to_string(),
                task_prompt: "Do a very long task that takes forever".to_string(),
                summary_prompt: "Summarize".to_string(),
                timeout_ms: None,
                allowed_tools: None,
            },
            event_stream.clone(),
            cx,
        )
    });

    cx.run_until_parked();

    // Signal cancellation via the event stream
    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);

    // The task should complete promptly with a cancellation error
    let timeout = cx.background_executor.timer(Duration::from_secs(5));
    let result = futures::select! {
        result = task.fuse() => result,
        _ = timeout.fuse() => {
            panic!("subagent tool did not respond to cancellation within timeout");
        }
    };

    // Verify we got a cancellation error
    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("cancelled by user"),
        "expected cancellation error, got: {}",
        err
    );
}
4192
// Creates `MAX_PARALLEL_SUBAGENTS` subagent threads for one parent, then checks
// that one more request is rejected with the expected error message.
#[gpui::test]
async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    cx.update(LanguageModelRegistry::test);
    let model = Arc::new(FakeLanguageModel::default());
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let native_agent = NativeAgent::new(
        project.clone(),
        thread_store,
        Templates::new(),
        None,
        fs,
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let parent_thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            project_context,
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Fill every available slot. The handles are retained until the end of the
    // test (presumably so the subagents keep counting toward the limit; confirm).
    let mut handles = Vec::new();
    for _ in 0..MAX_PARALLEL_SUBAGENTS {
        let handle = cx
            .update(|cx| {
                NativeThreadEnvironment::create_subagent_thread(
                    native_agent.downgrade(),
                    parent_thread.clone(),
                    "some title".to_string(),
                    "some task".to_string(),
                    None,
                    None,
                    cx,
                )
            })
            .expect("Expected to be able to create subagent thread");
        handles.push(handle);
    }

    // One request beyond the limit must fail.
    let result = cx.update(|cx| {
        NativeThreadEnvironment::create_subagent_thread(
            native_agent.downgrade(),
            parent_thread.clone(),
            "some title".to_string(),
            "some task".to_string(),
            None,
            None,
            cx,
        )
    });
    assert!(result.is_err());
    assert_eq!(
        result.err().unwrap().to_string(),
        format!(
            "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
            MAX_PARALLEL_SUBAGENTS
        )
    );
}
4271
4272#[gpui::test]
4273async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4274 init_test(cx);
4275
4276 always_allow_tools(cx);
4277
4278 cx.update(|cx| {
4279 cx.update_flags(true, vec!["subagents".to_string()]);
4280 });
4281
4282 let fs = FakeFs::new(cx.executor());
4283 fs.insert_tree(path!("/test"), json!({})).await;
4284 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4285 let project_context = cx.new(|_cx| ProjectContext::default());
4286 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4287 let context_server_registry =
4288 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4289 cx.update(LanguageModelRegistry::test);
4290 let model = Arc::new(FakeLanguageModel::default());
4291 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4292 let native_agent = NativeAgent::new(
4293 project.clone(),
4294 thread_store,
4295 Templates::new(),
4296 None,
4297 fs,
4298 &mut cx.to_async(),
4299 )
4300 .await
4301 .unwrap();
4302 let parent_thread = cx.new(|cx| {
4303 Thread::new(
4304 project.clone(),
4305 project_context,
4306 context_server_registry,
4307 Templates::new(),
4308 Some(model.clone()),
4309 cx,
4310 )
4311 });
4312
4313 let subagent_handle = cx
4314 .update(|cx| {
4315 NativeThreadEnvironment::create_subagent_thread(
4316 native_agent.downgrade(),
4317 parent_thread.clone(),
4318 "some title".to_string(),
4319 "task prompt".to_string(),
4320 Some(Duration::from_millis(10)),
4321 None,
4322 cx,
4323 )
4324 })
4325 .expect("Failed to create subagent");
4326
4327 let summary_task =
4328 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4329
4330 cx.run_until_parked();
4331
4332 {
4333 let messages = model.pending_completions().last().unwrap().messages.clone();
4334 // Ensure that model received a system prompt
4335 assert_eq!(messages[0].role, Role::System);
4336 // Ensure that model received a task prompt
4337 assert_eq!(messages[1].role, Role::User);
4338 assert_eq!(
4339 messages[1].content,
4340 vec![MessageContent::Text("task prompt".to_string())]
4341 );
4342 }
4343
4344 model.send_last_completion_stream_text_chunk("Some task response...");
4345 model.end_last_completion_stream();
4346
4347 cx.run_until_parked();
4348
4349 {
4350 let messages = model.pending_completions().last().unwrap().messages.clone();
4351 assert_eq!(messages[2].role, Role::Assistant);
4352 assert_eq!(
4353 messages[2].content,
4354 vec![MessageContent::Text("Some task response...".to_string())]
4355 );
4356 // Ensure that model received a summary prompt
4357 assert_eq!(messages[3].role, Role::User);
4358 assert_eq!(
4359 messages[3].content,
4360 vec![MessageContent::Text("summary prompt".to_string())]
4361 );
4362 }
4363
4364 model.send_last_completion_stream_text_chunk("Some summary...");
4365 model.end_last_completion_stream();
4366
4367 let result = summary_task.await;
4368 assert_eq!(result.unwrap(), "Some summary...\n");
4369}
4370
4371#[gpui::test]
4372async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4373 cx: &mut TestAppContext,
4374) {
4375 init_test(cx);
4376
4377 always_allow_tools(cx);
4378
4379 cx.update(|cx| {
4380 cx.update_flags(true, vec!["subagents".to_string()]);
4381 });
4382
4383 let fs = FakeFs::new(cx.executor());
4384 fs.insert_tree(path!("/test"), json!({})).await;
4385 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4386 let project_context = cx.new(|_cx| ProjectContext::default());
4387 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4388 let context_server_registry =
4389 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4390 cx.update(LanguageModelRegistry::test);
4391 let model = Arc::new(FakeLanguageModel::default());
4392 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4393 let native_agent = NativeAgent::new(
4394 project.clone(),
4395 thread_store,
4396 Templates::new(),
4397 None,
4398 fs,
4399 &mut cx.to_async(),
4400 )
4401 .await
4402 .unwrap();
4403 let parent_thread = cx.new(|cx| {
4404 Thread::new(
4405 project.clone(),
4406 project_context,
4407 context_server_registry,
4408 Templates::new(),
4409 Some(model.clone()),
4410 cx,
4411 )
4412 });
4413
4414 let subagent_handle = cx
4415 .update(|cx| {
4416 NativeThreadEnvironment::create_subagent_thread(
4417 native_agent.downgrade(),
4418 parent_thread.clone(),
4419 "some title".to_string(),
4420 "task prompt".to_string(),
4421 Some(Duration::from_millis(100)),
4422 None,
4423 cx,
4424 )
4425 })
4426 .expect("Failed to create subagent");
4427
4428 let summary_task =
4429 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4430
4431 cx.run_until_parked();
4432
4433 {
4434 let messages = model.pending_completions().last().unwrap().messages.clone();
4435 // Ensure that model received a system prompt
4436 assert_eq!(messages[0].role, Role::System);
4437 // Ensure that model received a task prompt
4438 assert_eq!(
4439 messages[1].content,
4440 vec![MessageContent::Text("task prompt".to_string())]
4441 );
4442 }
4443
4444 // Don't complete the initial model stream — let the timeout expire instead.
4445 cx.executor().advance_clock(Duration::from_millis(200));
4446 cx.run_until_parked();
4447
4448 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4449 // instead of the summary_prompt.
4450 {
4451 let messages = model.pending_completions().last().unwrap().messages.clone();
4452 let last_user_message = messages
4453 .iter()
4454 .rev()
4455 .find(|m| m.role == Role::User)
4456 .unwrap();
4457 assert_eq!(
4458 last_user_message.content,
4459 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
4460 );
4461 }
4462
4463 model.send_last_completion_stream_text_chunk("Some context low response...");
4464 model.end_last_completion_stream();
4465
4466 let result = summary_task.await;
4467 assert_eq!(result.unwrap(), "Some context low response...\n");
4468}
4469
4470#[gpui::test]
4471async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4472 init_test(cx);
4473
4474 always_allow_tools(cx);
4475
4476 cx.update(|cx| {
4477 cx.update_flags(true, vec!["subagents".to_string()]);
4478 });
4479
4480 let fs = FakeFs::new(cx.executor());
4481 fs.insert_tree(path!("/test"), json!({})).await;
4482 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4483 let project_context = cx.new(|_cx| ProjectContext::default());
4484 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4485 let context_server_registry =
4486 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4487 cx.update(LanguageModelRegistry::test);
4488 let model = Arc::new(FakeLanguageModel::default());
4489 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4490 let native_agent = NativeAgent::new(
4491 project.clone(),
4492 thread_store,
4493 Templates::new(),
4494 None,
4495 fs,
4496 &mut cx.to_async(),
4497 )
4498 .await
4499 .unwrap();
4500 let parent_thread = cx.new(|cx| {
4501 let mut thread = Thread::new(
4502 project.clone(),
4503 project_context,
4504 context_server_registry,
4505 Templates::new(),
4506 Some(model.clone()),
4507 cx,
4508 );
4509 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4510 thread.add_tool(GrepTool::new(project.clone()), None);
4511 thread
4512 });
4513
4514 let _subagent_handle = cx
4515 .update(|cx| {
4516 NativeThreadEnvironment::create_subagent_thread(
4517 native_agent.downgrade(),
4518 parent_thread.clone(),
4519 "some title".to_string(),
4520 "task prompt".to_string(),
4521 Some(Duration::from_millis(10)),
4522 None,
4523 cx,
4524 )
4525 })
4526 .expect("Failed to create subagent");
4527
4528 cx.run_until_parked();
4529
4530 let tools = model
4531 .pending_completions()
4532 .last()
4533 .unwrap()
4534 .tools
4535 .iter()
4536 .map(|tool| tool.name.clone())
4537 .collect::<Vec<_>>();
4538 assert_eq!(tools.len(), 2);
4539 assert!(tools.contains(&"grep".to_string()));
4540 assert!(tools.contains(&"list_directory".to_string()));
4541}
4542
4543#[gpui::test]
4544async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
4545 init_test(cx);
4546
4547 always_allow_tools(cx);
4548
4549 cx.update(|cx| {
4550 cx.update_flags(true, vec!["subagents".to_string()]);
4551 });
4552
4553 let fs = FakeFs::new(cx.executor());
4554 fs.insert_tree(path!("/test"), json!({})).await;
4555 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4556 let project_context = cx.new(|_cx| ProjectContext::default());
4557 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4558 let context_server_registry =
4559 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4560 cx.update(LanguageModelRegistry::test);
4561 let model = Arc::new(FakeLanguageModel::default());
4562 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4563 let native_agent = NativeAgent::new(
4564 project.clone(),
4565 thread_store,
4566 Templates::new(),
4567 None,
4568 fs,
4569 &mut cx.to_async(),
4570 )
4571 .await
4572 .unwrap();
4573 let parent_thread = cx.new(|cx| {
4574 let mut thread = Thread::new(
4575 project.clone(),
4576 project_context,
4577 context_server_registry,
4578 Templates::new(),
4579 Some(model.clone()),
4580 cx,
4581 );
4582 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4583 thread.add_tool(GrepTool::new(project.clone()), None);
4584 thread
4585 });
4586
4587 let _subagent_handle = cx
4588 .update(|cx| {
4589 NativeThreadEnvironment::create_subagent_thread(
4590 native_agent.downgrade(),
4591 parent_thread.clone(),
4592 "some title".to_string(),
4593 "task prompt".to_string(),
4594 Some(Duration::from_millis(10)),
4595 Some(vec!["grep".to_string()]),
4596 cx,
4597 )
4598 })
4599 .expect("Failed to create subagent");
4600
4601 cx.run_until_parked();
4602
4603 let tools = model
4604 .pending_completions()
4605 .last()
4606 .unwrap()
4607 .tools
4608 .iter()
4609 .map(|tool| tool.name.clone())
4610 .collect::<Vec<_>>();
4611 assert_eq!(tools, vec!["grep"]);
4612}
4613
4614#[gpui::test]
4615async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4616 init_test(cx);
4617
4618 let fs = FakeFs::new(cx.executor());
4619 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4620 .await;
4621 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4622
4623 cx.update(|cx| {
4624 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4625 settings.tool_permissions.tools.insert(
4626 EditFileTool::NAME.into(),
4627 agent_settings::ToolRules {
4628 default_mode: settings::ToolPermissionMode::Allow,
4629 always_allow: vec![],
4630 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4631 always_confirm: vec![],
4632 invalid_patterns: vec![],
4633 },
4634 );
4635 agent_settings::AgentSettings::override_global(settings, cx);
4636 });
4637
4638 let context_server_registry =
4639 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4640 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4641 let templates = crate::Templates::new();
4642 let thread = cx.new(|cx| {
4643 crate::Thread::new(
4644 project.clone(),
4645 cx.new(|_cx| prompt_store::ProjectContext::default()),
4646 context_server_registry,
4647 templates.clone(),
4648 None,
4649 cx,
4650 )
4651 });
4652
4653 #[allow(clippy::arc_with_non_send_sync)]
4654 let tool = Arc::new(crate::EditFileTool::new(
4655 project.clone(),
4656 thread.downgrade(),
4657 language_registry,
4658 templates,
4659 ));
4660 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4661
4662 let task = cx.update(|cx| {
4663 tool.run(
4664 crate::EditFileToolInput {
4665 display_description: "Edit sensitive file".to_string(),
4666 path: "root/sensitive_config.txt".into(),
4667 mode: crate::EditFileMode::Edit,
4668 },
4669 event_stream,
4670 cx,
4671 )
4672 });
4673
4674 let result = task.await;
4675 assert!(result.is_err(), "expected edit to be blocked");
4676 assert!(
4677 result.unwrap_err().to_string().contains("blocked"),
4678 "error should mention the edit was blocked"
4679 );
4680}
4681
4682#[gpui::test]
4683async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4684 init_test(cx);
4685
4686 let fs = FakeFs::new(cx.executor());
4687 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4688 .await;
4689 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4690
4691 cx.update(|cx| {
4692 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4693 settings.tool_permissions.tools.insert(
4694 DeletePathTool::NAME.into(),
4695 agent_settings::ToolRules {
4696 default_mode: settings::ToolPermissionMode::Allow,
4697 always_allow: vec![],
4698 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4699 always_confirm: vec![],
4700 invalid_patterns: vec![],
4701 },
4702 );
4703 agent_settings::AgentSettings::override_global(settings, cx);
4704 });
4705
4706 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4707
4708 #[allow(clippy::arc_with_non_send_sync)]
4709 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4710 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4711
4712 let task = cx.update(|cx| {
4713 tool.run(
4714 crate::DeletePathToolInput {
4715 path: "root/important_data.txt".to_string(),
4716 },
4717 event_stream,
4718 cx,
4719 )
4720 });
4721
4722 let result = task.await;
4723 assert!(result.is_err(), "expected deletion to be blocked");
4724 assert!(
4725 result.unwrap_err().to_string().contains("blocked"),
4726 "error should mention the deletion was blocked"
4727 );
4728}
4729
4730#[gpui::test]
4731async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4732 init_test(cx);
4733
4734 let fs = FakeFs::new(cx.executor());
4735 fs.insert_tree(
4736 "/root",
4737 json!({
4738 "safe.txt": "content",
4739 "protected": {}
4740 }),
4741 )
4742 .await;
4743 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4744
4745 cx.update(|cx| {
4746 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4747 settings.tool_permissions.tools.insert(
4748 MovePathTool::NAME.into(),
4749 agent_settings::ToolRules {
4750 default_mode: settings::ToolPermissionMode::Allow,
4751 always_allow: vec![],
4752 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4753 always_confirm: vec![],
4754 invalid_patterns: vec![],
4755 },
4756 );
4757 agent_settings::AgentSettings::override_global(settings, cx);
4758 });
4759
4760 #[allow(clippy::arc_with_non_send_sync)]
4761 let tool = Arc::new(crate::MovePathTool::new(project));
4762 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4763
4764 let task = cx.update(|cx| {
4765 tool.run(
4766 crate::MovePathToolInput {
4767 source_path: "root/safe.txt".to_string(),
4768 destination_path: "root/protected/safe.txt".to_string(),
4769 },
4770 event_stream,
4771 cx,
4772 )
4773 });
4774
4775 let result = task.await;
4776 assert!(
4777 result.is_err(),
4778 "expected move to be blocked due to destination path"
4779 );
4780 assert!(
4781 result.unwrap_err().to_string().contains("blocked"),
4782 "error should mention the move was blocked"
4783 );
4784}
4785
4786#[gpui::test]
4787async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
4788 init_test(cx);
4789
4790 let fs = FakeFs::new(cx.executor());
4791 fs.insert_tree(
4792 "/root",
4793 json!({
4794 "secret.txt": "secret content",
4795 "public": {}
4796 }),
4797 )
4798 .await;
4799 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4800
4801 cx.update(|cx| {
4802 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4803 settings.tool_permissions.tools.insert(
4804 MovePathTool::NAME.into(),
4805 agent_settings::ToolRules {
4806 default_mode: settings::ToolPermissionMode::Allow,
4807 always_allow: vec![],
4808 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
4809 always_confirm: vec![],
4810 invalid_patterns: vec![],
4811 },
4812 );
4813 agent_settings::AgentSettings::override_global(settings, cx);
4814 });
4815
4816 #[allow(clippy::arc_with_non_send_sync)]
4817 let tool = Arc::new(crate::MovePathTool::new(project));
4818 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4819
4820 let task = cx.update(|cx| {
4821 tool.run(
4822 crate::MovePathToolInput {
4823 source_path: "root/secret.txt".to_string(),
4824 destination_path: "root/public/not_secret.txt".to_string(),
4825 },
4826 event_stream,
4827 cx,
4828 )
4829 });
4830
4831 let result = task.await;
4832 assert!(
4833 result.is_err(),
4834 "expected move to be blocked due to source path"
4835 );
4836 assert!(
4837 result.unwrap_err().to_string().contains("blocked"),
4838 "error should mention the move was blocked"
4839 );
4840}
4841
4842#[gpui::test]
4843async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
4844 init_test(cx);
4845
4846 let fs = FakeFs::new(cx.executor());
4847 fs.insert_tree(
4848 "/root",
4849 json!({
4850 "confidential.txt": "confidential data",
4851 "dest": {}
4852 }),
4853 )
4854 .await;
4855 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4856
4857 cx.update(|cx| {
4858 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4859 settings.tool_permissions.tools.insert(
4860 CopyPathTool::NAME.into(),
4861 agent_settings::ToolRules {
4862 default_mode: settings::ToolPermissionMode::Allow,
4863 always_allow: vec![],
4864 always_deny: vec![
4865 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
4866 ],
4867 always_confirm: vec![],
4868 invalid_patterns: vec![],
4869 },
4870 );
4871 agent_settings::AgentSettings::override_global(settings, cx);
4872 });
4873
4874 #[allow(clippy::arc_with_non_send_sync)]
4875 let tool = Arc::new(crate::CopyPathTool::new(project));
4876 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4877
4878 let task = cx.update(|cx| {
4879 tool.run(
4880 crate::CopyPathToolInput {
4881 source_path: "root/confidential.txt".to_string(),
4882 destination_path: "root/dest/copy.txt".to_string(),
4883 },
4884 event_stream,
4885 cx,
4886 )
4887 });
4888
4889 let result = task.await;
4890 assert!(result.is_err(), "expected copy to be blocked");
4891 assert!(
4892 result.unwrap_err().to_string().contains("blocked"),
4893 "error should mention the copy was blocked"
4894 );
4895}
4896
4897#[gpui::test]
4898async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
4899 init_test(cx);
4900
4901 let fs = FakeFs::new(cx.executor());
4902 fs.insert_tree(
4903 "/root",
4904 json!({
4905 "normal.txt": "normal content",
4906 "readonly": {
4907 "config.txt": "readonly content"
4908 }
4909 }),
4910 )
4911 .await;
4912 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4913
4914 cx.update(|cx| {
4915 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4916 settings.tool_permissions.tools.insert(
4917 SaveFileTool::NAME.into(),
4918 agent_settings::ToolRules {
4919 default_mode: settings::ToolPermissionMode::Allow,
4920 always_allow: vec![],
4921 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
4922 always_confirm: vec![],
4923 invalid_patterns: vec![],
4924 },
4925 );
4926 agent_settings::AgentSettings::override_global(settings, cx);
4927 });
4928
4929 #[allow(clippy::arc_with_non_send_sync)]
4930 let tool = Arc::new(crate::SaveFileTool::new(project));
4931 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4932
4933 let task = cx.update(|cx| {
4934 tool.run(
4935 crate::SaveFileToolInput {
4936 paths: vec![
4937 std::path::PathBuf::from("root/normal.txt"),
4938 std::path::PathBuf::from("root/readonly/config.txt"),
4939 ],
4940 },
4941 event_stream,
4942 cx,
4943 )
4944 });
4945
4946 let result = task.await;
4947 assert!(
4948 result.is_err(),
4949 "expected save to be blocked due to denied path"
4950 );
4951 assert!(
4952 result.unwrap_err().to_string().contains("blocked"),
4953 "error should mention the save was blocked"
4954 );
4955}
4956
4957#[gpui::test]
4958async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
4959 init_test(cx);
4960
4961 let fs = FakeFs::new(cx.executor());
4962 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
4963 .await;
4964 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4965
4966 cx.update(|cx| {
4967 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4968 settings.always_allow_tool_actions = false;
4969 settings.tool_permissions.tools.insert(
4970 SaveFileTool::NAME.into(),
4971 agent_settings::ToolRules {
4972 default_mode: settings::ToolPermissionMode::Allow,
4973 always_allow: vec![],
4974 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
4975 always_confirm: vec![],
4976 invalid_patterns: vec![],
4977 },
4978 );
4979 agent_settings::AgentSettings::override_global(settings, cx);
4980 });
4981
4982 #[allow(clippy::arc_with_non_send_sync)]
4983 let tool = Arc::new(crate::SaveFileTool::new(project));
4984 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4985
4986 let task = cx.update(|cx| {
4987 tool.run(
4988 crate::SaveFileToolInput {
4989 paths: vec![std::path::PathBuf::from("root/config.secret")],
4990 },
4991 event_stream,
4992 cx,
4993 )
4994 });
4995
4996 let result = task.await;
4997 assert!(result.is_err(), "expected save to be blocked");
4998 assert!(
4999 result.unwrap_err().to_string().contains("blocked"),
5000 "error should mention the save was blocked"
5001 );
5002}
5003
5004#[gpui::test]
5005async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5006 init_test(cx);
5007
5008 cx.update(|cx| {
5009 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5010 settings.tool_permissions.tools.insert(
5011 WebSearchTool::NAME.into(),
5012 agent_settings::ToolRules {
5013 default_mode: settings::ToolPermissionMode::Allow,
5014 always_allow: vec![],
5015 always_deny: vec![
5016 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5017 ],
5018 always_confirm: vec![],
5019 invalid_patterns: vec![],
5020 },
5021 );
5022 agent_settings::AgentSettings::override_global(settings, cx);
5023 });
5024
5025 #[allow(clippy::arc_with_non_send_sync)]
5026 let tool = Arc::new(crate::WebSearchTool);
5027 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5028
5029 let input: crate::WebSearchToolInput =
5030 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5031
5032 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5033
5034 let result = task.await;
5035 assert!(result.is_err(), "expected search to be blocked");
5036 assert!(
5037 result.unwrap_err().to_string().contains("blocked"),
5038 "error should mention the search was blocked"
5039 );
5040}
5041
5042#[gpui::test]
5043async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5044 init_test(cx);
5045
5046 let fs = FakeFs::new(cx.executor());
5047 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5048 .await;
5049 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5050
5051 cx.update(|cx| {
5052 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5053 settings.always_allow_tool_actions = false;
5054 settings.tool_permissions.tools.insert(
5055 EditFileTool::NAME.into(),
5056 agent_settings::ToolRules {
5057 default_mode: settings::ToolPermissionMode::Confirm,
5058 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5059 always_deny: vec![],
5060 always_confirm: vec![],
5061 invalid_patterns: vec![],
5062 },
5063 );
5064 agent_settings::AgentSettings::override_global(settings, cx);
5065 });
5066
5067 let context_server_registry =
5068 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5069 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5070 let templates = crate::Templates::new();
5071 let thread = cx.new(|cx| {
5072 crate::Thread::new(
5073 project.clone(),
5074 cx.new(|_cx| prompt_store::ProjectContext::default()),
5075 context_server_registry,
5076 templates.clone(),
5077 None,
5078 cx,
5079 )
5080 });
5081
5082 #[allow(clippy::arc_with_non_send_sync)]
5083 let tool = Arc::new(crate::EditFileTool::new(
5084 project,
5085 thread.downgrade(),
5086 language_registry,
5087 templates,
5088 ));
5089 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5090
5091 let _task = cx.update(|cx| {
5092 tool.run(
5093 crate::EditFileToolInput {
5094 display_description: "Edit README".to_string(),
5095 path: "root/README.md".into(),
5096 mode: crate::EditFileMode::Edit,
5097 },
5098 event_stream,
5099 cx,
5100 )
5101 });
5102
5103 cx.run_until_parked();
5104
5105 let event = rx.try_next();
5106 assert!(
5107 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5108 "expected no authorization request for allowed .md file"
5109 );
5110}
5111
5112#[gpui::test]
5113async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5114 init_test(cx);
5115
5116 cx.update(|cx| {
5117 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5118 settings.tool_permissions.tools.insert(
5119 FetchTool::NAME.into(),
5120 agent_settings::ToolRules {
5121 default_mode: settings::ToolPermissionMode::Allow,
5122 always_allow: vec![],
5123 always_deny: vec![
5124 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5125 ],
5126 always_confirm: vec![],
5127 invalid_patterns: vec![],
5128 },
5129 );
5130 agent_settings::AgentSettings::override_global(settings, cx);
5131 });
5132
5133 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5134
5135 #[allow(clippy::arc_with_non_send_sync)]
5136 let tool = Arc::new(crate::FetchTool::new(http_client));
5137 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5138
5139 let input: crate::FetchToolInput =
5140 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5141
5142 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5143
5144 let result = task.await;
5145 assert!(result.is_err(), "expected fetch to be blocked");
5146 assert!(
5147 result.unwrap_err().to_string().contains("blocked"),
5148 "error should mention the fetch was blocked"
5149 );
5150}
5151
5152#[gpui::test]
5153async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5154 init_test(cx);
5155
5156 cx.update(|cx| {
5157 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5158 settings.always_allow_tool_actions = false;
5159 settings.tool_permissions.tools.insert(
5160 FetchTool::NAME.into(),
5161 agent_settings::ToolRules {
5162 default_mode: settings::ToolPermissionMode::Confirm,
5163 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5164 always_deny: vec![],
5165 always_confirm: vec![],
5166 invalid_patterns: vec![],
5167 },
5168 );
5169 agent_settings::AgentSettings::override_global(settings, cx);
5170 });
5171
5172 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5173
5174 #[allow(clippy::arc_with_non_send_sync)]
5175 let tool = Arc::new(crate::FetchTool::new(http_client));
5176 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5177
5178 let input: crate::FetchToolInput =
5179 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5180
5181 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5182
5183 cx.run_until_parked();
5184
5185 let event = rx.try_next();
5186 assert!(
5187 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5188 "expected no authorization request for allowed docs.rs URL"
5189 );
5190}
5191
5192#[gpui::test]
5193async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5194 init_test(cx);
5195 always_allow_tools(cx);
5196
5197 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5198 let fake_model = model.as_fake();
5199
5200 // Add a tool so we can simulate tool calls
5201 thread.update(cx, |thread, _cx| {
5202 thread.add_tool(EchoTool, None);
5203 });
5204
5205 // Start a turn by sending a message
5206 let mut events = thread
5207 .update(cx, |thread, cx| {
5208 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5209 })
5210 .unwrap();
5211 cx.run_until_parked();
5212
5213 // Simulate the model making a tool call
5214 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5215 LanguageModelToolUse {
5216 id: "tool_1".into(),
5217 name: "echo".into(),
5218 raw_input: r#"{"text": "hello"}"#.into(),
5219 input: json!({"text": "hello"}),
5220 is_input_complete: true,
5221 thought_signature: None,
5222 },
5223 ));
5224 fake_model
5225 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5226
5227 // Signal that a message is queued before ending the stream
5228 thread.update(cx, |thread, _cx| {
5229 thread.set_has_queued_message(true);
5230 });
5231
5232 // Now end the stream - tool will run, and the boundary check should see the queue
5233 fake_model.end_last_completion_stream();
5234
5235 // Collect all events until the turn stops
5236 let all_events = collect_events_until_stop(&mut events, cx).await;
5237
5238 // Verify we received the tool call event
5239 let tool_call_ids: Vec<_> = all_events
5240 .iter()
5241 .filter_map(|e| match e {
5242 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5243 _ => None,
5244 })
5245 .collect();
5246 assert_eq!(
5247 tool_call_ids,
5248 vec!["tool_1"],
5249 "Should have received a tool call event for our echo tool"
5250 );
5251
5252 // The turn should have stopped with EndTurn
5253 let stop_reasons = stop_events(all_events);
5254 assert_eq!(
5255 stop_reasons,
5256 vec![acp::StopReason::EndTurn],
5257 "Turn should have ended after tool completion due to queued message"
5258 );
5259
5260 // Verify the queued message flag is still set
5261 thread.update(cx, |thread, _cx| {
5262 assert!(
5263 thread.has_queued_message(),
5264 "Should still have queued message flag set"
5265 );
5266 });
5267
5268 // Thread should be idle now
5269 thread.update(cx, |thread, _cx| {
5270 assert!(
5271 thread.is_turn_complete(),
5272 "Thread should not be running after turn ends"
5273 );
5274 });
5275}