1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
33};
34use pretty_assertions::assert_eq;
35use project::{
36 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
37};
38use prompt_store::ProjectContext;
39use reqwest_client::ReqwestClient;
40use schemars::JsonSchema;
41use serde::{Deserialize, Serialize};
42use serde_json::json;
43use settings::{Settings, SettingsStore};
44use std::{
45 path::Path,
46 pin::Pin,
47 rc::Rc,
48 sync::{
49 Arc,
50 atomic::{AtomicBool, Ordering},
51 },
52 time::Duration,
53};
54use util::path;
55
56mod edit_file_thread_test;
57mod test_tools;
58use test_tools::*;
59
60fn init_test(cx: &mut TestAppContext) {
61 cx.update(|cx| {
62 let settings_store = SettingsStore::test(cx);
63 cx.set_global(settings_store);
64 });
65}
66
67struct FakeTerminalHandle {
68 killed: Arc<AtomicBool>,
69 stopped_by_user: Arc<AtomicBool>,
70 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
71 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
72 output: acp::TerminalOutputResponse,
73 id: acp::TerminalId,
74}
75
76impl FakeTerminalHandle {
77 fn new_never_exits(cx: &mut App) -> Self {
78 let killed = Arc::new(AtomicBool::new(false));
79 let stopped_by_user = Arc::new(AtomicBool::new(false));
80
81 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
82
83 let wait_for_exit = cx
84 .spawn(async move |_cx| {
85 // Wait for the exit signal (sent when kill() is called)
86 let _ = exit_receiver.await;
87 acp::TerminalExitStatus::new()
88 })
89 .shared();
90
91 Self {
92 killed,
93 stopped_by_user,
94 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
95 wait_for_exit,
96 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
97 id: acp::TerminalId::new("fake_terminal".to_string()),
98 }
99 }
100
101 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
102 let killed = Arc::new(AtomicBool::new(false));
103 let stopped_by_user = Arc::new(AtomicBool::new(false));
104 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
105
106 let wait_for_exit = cx
107 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
108 .shared();
109
110 Self {
111 killed,
112 stopped_by_user,
113 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
114 wait_for_exit,
115 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
116 id: acp::TerminalId::new("fake_terminal".to_string()),
117 }
118 }
119
120 fn was_killed(&self) -> bool {
121 self.killed.load(Ordering::SeqCst)
122 }
123
124 fn set_stopped_by_user(&self, stopped: bool) {
125 self.stopped_by_user.store(stopped, Ordering::SeqCst);
126 }
127
128 fn signal_exit(&self) {
129 if let Some(sender) = self.exit_sender.borrow_mut().take() {
130 let _ = sender.send(());
131 }
132 }
133}
134
135impl crate::TerminalHandle for FakeTerminalHandle {
136 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
137 Ok(self.id.clone())
138 }
139
140 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
141 Ok(self.output.clone())
142 }
143
144 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
145 Ok(self.wait_for_exit.clone())
146 }
147
148 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
149 self.killed.store(true, Ordering::SeqCst);
150 self.signal_exit();
151 Ok(())
152 }
153
154 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
155 Ok(self.stopped_by_user.load(Ordering::SeqCst))
156 }
157}
158
159struct FakeSubagentHandle {
160 session_id: acp::SessionId,
161 wait_for_summary_task: Shared<Task<String>>,
162}
163
164impl SubagentHandle for FakeSubagentHandle {
165 fn id(&self) -> acp::SessionId {
166 self.session_id.clone()
167 }
168
169 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
170 let task = self.wait_for_summary_task.clone();
171 cx.background_spawn(async move { Ok(task.await) })
172 }
173}
174
175#[derive(Default)]
176struct FakeThreadEnvironment {
177 terminal_handle: Option<Rc<FakeTerminalHandle>>,
178 subagent_handle: Option<Rc<FakeSubagentHandle>>,
179}
180
181impl FakeThreadEnvironment {
182 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
183 Self {
184 terminal_handle: Some(terminal_handle.into()),
185 ..self
186 }
187 }
188}
189
190impl crate::ThreadEnvironment for FakeThreadEnvironment {
191 fn create_terminal(
192 &self,
193 _command: String,
194 _cwd: Option<std::path::PathBuf>,
195 _output_byte_limit: Option<u64>,
196 _cx: &mut AsyncApp,
197 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
198 let handle = self
199 .terminal_handle
200 .clone()
201 .expect("Terminal handle not available on FakeThreadEnvironment");
202 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
203 }
204
205 fn create_subagent(
206 &self,
207 _parent_thread: Entity<Thread>,
208 _label: String,
209 _initial_prompt: String,
210 _timeout_ms: Option<Duration>,
211 _allowed_tools: Option<Vec<String>>,
212 _cx: &mut App,
213 ) -> Result<Rc<dyn SubagentHandle>> {
214 Ok(self
215 .subagent_handle
216 .clone()
217 .expect("Subagent handle not available on FakeThreadEnvironment")
218 as Rc<dyn SubagentHandle>)
219 }
220}
221
222/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
223struct MultiTerminalEnvironment {
224 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
225}
226
227impl MultiTerminalEnvironment {
228 fn new() -> Self {
229 Self {
230 handles: std::cell::RefCell::new(Vec::new()),
231 }
232 }
233
234 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
235 self.handles.borrow().clone()
236 }
237}
238
239impl crate::ThreadEnvironment for MultiTerminalEnvironment {
240 fn create_terminal(
241 &self,
242 _command: String,
243 _cwd: Option<std::path::PathBuf>,
244 _output_byte_limit: Option<u64>,
245 cx: &mut AsyncApp,
246 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
247 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
248 self.handles.borrow_mut().push(handle.clone());
249 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
250 }
251
252 fn create_subagent(
253 &self,
254 _parent_thread: Entity<Thread>,
255 _label: String,
256 _initial_prompt: String,
257 _timeout: Option<Duration>,
258 _allowed_tools: Option<Vec<String>>,
259 _cx: &mut App,
260 ) -> Result<Rc<dyn SubagentHandle>> {
261 unimplemented!()
262 }
263}
264
265fn always_allow_tools(cx: &mut TestAppContext) {
266 cx.update(|cx| {
267 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
268 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
269 agent_settings::AgentSettings::override_global(settings, cx);
270 });
271}
272
273#[gpui::test]
274async fn test_echo(cx: &mut TestAppContext) {
275 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
276 let fake_model = model.as_fake();
277
278 let events = thread
279 .update(cx, |thread, cx| {
280 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
281 })
282 .unwrap();
283 cx.run_until_parked();
284 fake_model.send_last_completion_stream_text_chunk("Hello");
285 fake_model
286 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
287 fake_model.end_last_completion_stream();
288
289 let events = events.collect().await;
290 thread.update(cx, |thread, _cx| {
291 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
292 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
293 });
294 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
295}
296
297#[gpui::test]
298async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
299 init_test(cx);
300 always_allow_tools(cx);
301
302 let fs = FakeFs::new(cx.executor());
303 let project = Project::test(fs, [], cx).await;
304
305 let environment = Rc::new(cx.update(|cx| {
306 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
307 }));
308 let handle = environment.terminal_handle.clone().unwrap();
309
310 #[allow(clippy::arc_with_non_send_sync)]
311 let tool = Arc::new(crate::TerminalTool::new(project, environment));
312 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
313
314 let task = cx.update(|cx| {
315 tool.run(
316 crate::TerminalToolInput {
317 command: "sleep 1000".to_string(),
318 cd: ".".to_string(),
319 timeout_ms: Some(5),
320 },
321 event_stream,
322 cx,
323 )
324 });
325
326 let update = rx.expect_update_fields().await;
327 assert!(
328 update.content.iter().any(|blocks| {
329 blocks
330 .iter()
331 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
332 }),
333 "expected tool call update to include terminal content"
334 );
335
336 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
337
338 let deadline = std::time::Instant::now() + Duration::from_millis(500);
339 loop {
340 if let Some(result) = task_future.as_mut().now_or_never() {
341 let result = result.expect("terminal tool task should complete");
342
343 assert!(
344 handle.was_killed(),
345 "expected terminal handle to be killed on timeout"
346 );
347 assert!(
348 result.contains("partial output"),
349 "expected result to include terminal output, got: {result}"
350 );
351 return;
352 }
353
354 if std::time::Instant::now() >= deadline {
355 panic!("timed out waiting for terminal tool task to complete");
356 }
357
358 cx.run_until_parked();
359 cx.background_executor.timer(Duration::from_millis(1)).await;
360 }
361}
362
363#[gpui::test]
364#[ignore]
365async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
366 init_test(cx);
367 always_allow_tools(cx);
368
369 let fs = FakeFs::new(cx.executor());
370 let project = Project::test(fs, [], cx).await;
371
372 let environment = Rc::new(cx.update(|cx| {
373 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
374 }));
375 let handle = environment.terminal_handle.clone().unwrap();
376
377 #[allow(clippy::arc_with_non_send_sync)]
378 let tool = Arc::new(crate::TerminalTool::new(project, environment));
379 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
380
381 let _task = cx.update(|cx| {
382 tool.run(
383 crate::TerminalToolInput {
384 command: "sleep 1000".to_string(),
385 cd: ".".to_string(),
386 timeout_ms: None,
387 },
388 event_stream,
389 cx,
390 )
391 });
392
393 let update = rx.expect_update_fields().await;
394 assert!(
395 update.content.iter().any(|blocks| {
396 blocks
397 .iter()
398 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
399 }),
400 "expected tool call update to include terminal content"
401 );
402
403 cx.background_executor
404 .timer(Duration::from_millis(25))
405 .await;
406
407 assert!(
408 !handle.was_killed(),
409 "did not expect terminal handle to be killed without a timeout"
410 );
411}
412
413#[gpui::test]
414async fn test_thinking(cx: &mut TestAppContext) {
415 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
416 let fake_model = model.as_fake();
417
418 let events = thread
419 .update(cx, |thread, cx| {
420 thread.send(
421 UserMessageId::new(),
422 [indoc! {"
423 Testing:
424
425 Generate a thinking step where you just think the word 'Think',
426 and have your final answer be 'Hello'
427 "}],
428 cx,
429 )
430 })
431 .unwrap();
432 cx.run_until_parked();
433 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
434 text: "Think".to_string(),
435 signature: None,
436 });
437 fake_model.send_last_completion_stream_text_chunk("Hello");
438 fake_model
439 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
440 fake_model.end_last_completion_stream();
441
442 let events = events.collect().await;
443 thread.update(cx, |thread, _cx| {
444 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
445 assert_eq!(
446 thread.last_message().unwrap().to_markdown(),
447 indoc! {"
448 <think>Think</think>
449 Hello
450 "}
451 )
452 });
453 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
454}
455
456#[gpui::test]
457async fn test_system_prompt(cx: &mut TestAppContext) {
458 let ThreadTest {
459 model,
460 thread,
461 project_context,
462 ..
463 } = setup(cx, TestModel::Fake).await;
464 let fake_model = model.as_fake();
465
466 project_context.update(cx, |project_context, _cx| {
467 project_context.shell = "test-shell".into()
468 });
469 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
470 thread
471 .update(cx, |thread, cx| {
472 thread.send(UserMessageId::new(), ["abc"], cx)
473 })
474 .unwrap();
475 cx.run_until_parked();
476 let mut pending_completions = fake_model.pending_completions();
477 assert_eq!(
478 pending_completions.len(),
479 1,
480 "unexpected pending completions: {:?}",
481 pending_completions
482 );
483
484 let pending_completion = pending_completions.pop().unwrap();
485 assert_eq!(pending_completion.messages[0].role, Role::System);
486
487 let system_message = &pending_completion.messages[0];
488 let system_prompt = system_message.content[0].to_str().unwrap();
489 assert!(
490 system_prompt.contains("test-shell"),
491 "unexpected system message: {:?}",
492 system_message
493 );
494 assert!(
495 system_prompt.contains("## Fixing Diagnostics"),
496 "unexpected system message: {:?}",
497 system_message
498 );
499}
500
501#[gpui::test]
502async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
503 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
504 let fake_model = model.as_fake();
505
506 thread
507 .update(cx, |thread, cx| {
508 thread.send(UserMessageId::new(), ["abc"], cx)
509 })
510 .unwrap();
511 cx.run_until_parked();
512 let mut pending_completions = fake_model.pending_completions();
513 assert_eq!(
514 pending_completions.len(),
515 1,
516 "unexpected pending completions: {:?}",
517 pending_completions
518 );
519
520 let pending_completion = pending_completions.pop().unwrap();
521 assert_eq!(pending_completion.messages[0].role, Role::System);
522
523 let system_message = &pending_completion.messages[0];
524 let system_prompt = system_message.content[0].to_str().unwrap();
525 assert!(
526 !system_prompt.contains("## Tool Use"),
527 "unexpected system message: {:?}",
528 system_message
529 );
530 assert!(
531 !system_prompt.contains("## Fixing Diagnostics"),
532 "unexpected system message: {:?}",
533 system_message
534 );
535}
536
537#[gpui::test]
538async fn test_prompt_caching(cx: &mut TestAppContext) {
539 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
540 let fake_model = model.as_fake();
541
542 // Send initial user message and verify it's cached
543 thread
544 .update(cx, |thread, cx| {
545 thread.send(UserMessageId::new(), ["Message 1"], cx)
546 })
547 .unwrap();
548 cx.run_until_parked();
549
550 let completion = fake_model.pending_completions().pop().unwrap();
551 assert_eq!(
552 completion.messages[1..],
553 vec![LanguageModelRequestMessage {
554 role: Role::User,
555 content: vec!["Message 1".into()],
556 cache: true,
557 reasoning_details: None,
558 }]
559 );
560 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
561 "Response to Message 1".into(),
562 ));
563 fake_model.end_last_completion_stream();
564 cx.run_until_parked();
565
566 // Send another user message and verify only the latest is cached
567 thread
568 .update(cx, |thread, cx| {
569 thread.send(UserMessageId::new(), ["Message 2"], cx)
570 })
571 .unwrap();
572 cx.run_until_parked();
573
574 let completion = fake_model.pending_completions().pop().unwrap();
575 assert_eq!(
576 completion.messages[1..],
577 vec![
578 LanguageModelRequestMessage {
579 role: Role::User,
580 content: vec!["Message 1".into()],
581 cache: false,
582 reasoning_details: None,
583 },
584 LanguageModelRequestMessage {
585 role: Role::Assistant,
586 content: vec!["Response to Message 1".into()],
587 cache: false,
588 reasoning_details: None,
589 },
590 LanguageModelRequestMessage {
591 role: Role::User,
592 content: vec!["Message 2".into()],
593 cache: true,
594 reasoning_details: None,
595 }
596 ]
597 );
598 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
599 "Response to Message 2".into(),
600 ));
601 fake_model.end_last_completion_stream();
602 cx.run_until_parked();
603
604 // Simulate a tool call and verify that the latest tool result is cached
605 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
606 thread
607 .update(cx, |thread, cx| {
608 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
609 })
610 .unwrap();
611 cx.run_until_parked();
612
613 let tool_use = LanguageModelToolUse {
614 id: "tool_1".into(),
615 name: EchoTool::NAME.into(),
616 raw_input: json!({"text": "test"}).to_string(),
617 input: json!({"text": "test"}),
618 is_input_complete: true,
619 thought_signature: None,
620 };
621 fake_model
622 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
623 fake_model.end_last_completion_stream();
624 cx.run_until_parked();
625
626 let completion = fake_model.pending_completions().pop().unwrap();
627 let tool_result = LanguageModelToolResult {
628 tool_use_id: "tool_1".into(),
629 tool_name: EchoTool::NAME.into(),
630 is_error: false,
631 content: "test".into(),
632 output: Some("test".into()),
633 };
634 assert_eq!(
635 completion.messages[1..],
636 vec![
637 LanguageModelRequestMessage {
638 role: Role::User,
639 content: vec!["Message 1".into()],
640 cache: false,
641 reasoning_details: None,
642 },
643 LanguageModelRequestMessage {
644 role: Role::Assistant,
645 content: vec!["Response to Message 1".into()],
646 cache: false,
647 reasoning_details: None,
648 },
649 LanguageModelRequestMessage {
650 role: Role::User,
651 content: vec!["Message 2".into()],
652 cache: false,
653 reasoning_details: None,
654 },
655 LanguageModelRequestMessage {
656 role: Role::Assistant,
657 content: vec!["Response to Message 2".into()],
658 cache: false,
659 reasoning_details: None,
660 },
661 LanguageModelRequestMessage {
662 role: Role::User,
663 content: vec!["Use the echo tool".into()],
664 cache: false,
665 reasoning_details: None,
666 },
667 LanguageModelRequestMessage {
668 role: Role::Assistant,
669 content: vec![MessageContent::ToolUse(tool_use)],
670 cache: false,
671 reasoning_details: None,
672 },
673 LanguageModelRequestMessage {
674 role: Role::User,
675 content: vec![MessageContent::ToolResult(tool_result)],
676 cache: true,
677 reasoning_details: None,
678 }
679 ]
680 );
681}
682
683#[gpui::test]
684#[cfg_attr(not(feature = "e2e"), ignore)]
685async fn test_basic_tool_calls(cx: &mut TestAppContext) {
686 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
687
688 // Test a tool call that's likely to complete *before* streaming stops.
689 let events = thread
690 .update(cx, |thread, cx| {
691 thread.add_tool(EchoTool, None);
692 thread.send(
693 UserMessageId::new(),
694 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
695 cx,
696 )
697 })
698 .unwrap()
699 .collect()
700 .await;
701 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
702
703 // Test a tool calls that's likely to complete *after* streaming stops.
704 let events = thread
705 .update(cx, |thread, cx| {
706 thread.remove_tool(&EchoTool::NAME);
707 thread.add_tool(DelayTool, None);
708 thread.send(
709 UserMessageId::new(),
710 [
711 "Now call the delay tool with 200ms.",
712 "When the timer goes off, then you echo the output of the tool.",
713 ],
714 cx,
715 )
716 })
717 .unwrap()
718 .collect()
719 .await;
720 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
721 thread.update(cx, |thread, _cx| {
722 assert!(
723 thread
724 .last_message()
725 .unwrap()
726 .as_agent_message()
727 .unwrap()
728 .content
729 .iter()
730 .any(|content| {
731 if let AgentMessageContent::Text(text) = content {
732 text.contains("Ding")
733 } else {
734 false
735 }
736 }),
737 "{}",
738 thread.to_markdown()
739 );
740 });
741}
742
743#[gpui::test]
744#[cfg_attr(not(feature = "e2e"), ignore)]
745async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
746 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
747
748 // Test a tool call that's likely to complete *before* streaming stops.
749 let mut events = thread
750 .update(cx, |thread, cx| {
751 thread.add_tool(WordListTool, None);
752 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
753 })
754 .unwrap();
755
756 let mut saw_partial_tool_use = false;
757 while let Some(event) = events.next().await {
758 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
759 thread.update(cx, |thread, _cx| {
760 // Look for a tool use in the thread's last message
761 let message = thread.last_message().unwrap();
762 let agent_message = message.as_agent_message().unwrap();
763 let last_content = agent_message.content.last().unwrap();
764 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
765 assert_eq!(last_tool_use.name.as_ref(), "word_list");
766 if tool_call.status == acp::ToolCallStatus::Pending {
767 if !last_tool_use.is_input_complete
768 && last_tool_use.input.get("g").is_none()
769 {
770 saw_partial_tool_use = true;
771 }
772 } else {
773 last_tool_use
774 .input
775 .get("a")
776 .expect("'a' has streamed because input is now complete");
777 last_tool_use
778 .input
779 .get("g")
780 .expect("'g' has streamed because input is now complete");
781 }
782 } else {
783 panic!("last content should be a tool use");
784 }
785 });
786 }
787 }
788
789 assert!(
790 saw_partial_tool_use,
791 "should see at least one partially streamed tool use in the history"
792 );
793}
794
795#[gpui::test]
796async fn test_tool_authorization(cx: &mut TestAppContext) {
797 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
798 let fake_model = model.as_fake();
799
800 let mut events = thread
801 .update(cx, |thread, cx| {
802 thread.add_tool(ToolRequiringPermission, None);
803 thread.send(UserMessageId::new(), ["abc"], cx)
804 })
805 .unwrap();
806 cx.run_until_parked();
807 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
808 LanguageModelToolUse {
809 id: "tool_id_1".into(),
810 name: ToolRequiringPermission::NAME.into(),
811 raw_input: "{}".into(),
812 input: json!({}),
813 is_input_complete: true,
814 thought_signature: None,
815 },
816 ));
817 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
818 LanguageModelToolUse {
819 id: "tool_id_2".into(),
820 name: ToolRequiringPermission::NAME.into(),
821 raw_input: "{}".into(),
822 input: json!({}),
823 is_input_complete: true,
824 thought_signature: None,
825 },
826 ));
827 fake_model.end_last_completion_stream();
828 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
829 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
830
831 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
832 tool_call_auth_1
833 .response
834 .send(acp::PermissionOptionId::new("allow"))
835 .unwrap();
836 cx.run_until_parked();
837
838 // Reject the second - send "deny" option_id directly since Deny is now a button
839 tool_call_auth_2
840 .response
841 .send(acp::PermissionOptionId::new("deny"))
842 .unwrap();
843 cx.run_until_parked();
844
845 let completion = fake_model.pending_completions().pop().unwrap();
846 let message = completion.messages.last().unwrap();
847 assert_eq!(
848 message.content,
849 vec![
850 language_model::MessageContent::ToolResult(LanguageModelToolResult {
851 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
852 tool_name: ToolRequiringPermission::NAME.into(),
853 is_error: false,
854 content: "Allowed".into(),
855 output: Some("Allowed".into())
856 }),
857 language_model::MessageContent::ToolResult(LanguageModelToolResult {
858 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
859 tool_name: ToolRequiringPermission::NAME.into(),
860 is_error: true,
861 content: "Permission to run tool denied by user".into(),
862 output: Some("Permission to run tool denied by user".into())
863 })
864 ]
865 );
866
867 // Simulate yet another tool call.
868 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
869 LanguageModelToolUse {
870 id: "tool_id_3".into(),
871 name: ToolRequiringPermission::NAME.into(),
872 raw_input: "{}".into(),
873 input: json!({}),
874 is_input_complete: true,
875 thought_signature: None,
876 },
877 ));
878 fake_model.end_last_completion_stream();
879
880 // Respond by always allowing tools - send transformed option_id
881 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
882 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
883 tool_call_auth_3
884 .response
885 .send(acp::PermissionOptionId::new(
886 "always_allow:tool_requiring_permission",
887 ))
888 .unwrap();
889 cx.run_until_parked();
890 let completion = fake_model.pending_completions().pop().unwrap();
891 let message = completion.messages.last().unwrap();
892 assert_eq!(
893 message.content,
894 vec![language_model::MessageContent::ToolResult(
895 LanguageModelToolResult {
896 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
897 tool_name: ToolRequiringPermission::NAME.into(),
898 is_error: false,
899 content: "Allowed".into(),
900 output: Some("Allowed".into())
901 }
902 )]
903 );
904
905 // Simulate a final tool call, ensuring we don't trigger authorization.
906 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
907 LanguageModelToolUse {
908 id: "tool_id_4".into(),
909 name: ToolRequiringPermission::NAME.into(),
910 raw_input: "{}".into(),
911 input: json!({}),
912 is_input_complete: true,
913 thought_signature: None,
914 },
915 ));
916 fake_model.end_last_completion_stream();
917 cx.run_until_parked();
918 let completion = fake_model.pending_completions().pop().unwrap();
919 let message = completion.messages.last().unwrap();
920 assert_eq!(
921 message.content,
922 vec![language_model::MessageContent::ToolResult(
923 LanguageModelToolResult {
924 tool_use_id: "tool_id_4".into(),
925 tool_name: ToolRequiringPermission::NAME.into(),
926 is_error: false,
927 content: "Allowed".into(),
928 output: Some("Allowed".into())
929 }
930 )]
931 );
932}
933
934#[gpui::test]
935async fn test_tool_hallucination(cx: &mut TestAppContext) {
936 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
937 let fake_model = model.as_fake();
938
939 let mut events = thread
940 .update(cx, |thread, cx| {
941 thread.send(UserMessageId::new(), ["abc"], cx)
942 })
943 .unwrap();
944 cx.run_until_parked();
945 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
946 LanguageModelToolUse {
947 id: "tool_id_1".into(),
948 name: "nonexistent_tool".into(),
949 raw_input: "{}".into(),
950 input: json!({}),
951 is_input_complete: true,
952 thought_signature: None,
953 },
954 ));
955 fake_model.end_last_completion_stream();
956
957 let tool_call = expect_tool_call(&mut events).await;
958 assert_eq!(tool_call.title, "nonexistent_tool");
959 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
960 let update = expect_tool_call_update_fields(&mut events).await;
961 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
962}
963
964async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
965 let event = events
966 .next()
967 .await
968 .expect("no tool call authorization event received")
969 .unwrap();
970 match event {
971 ThreadEvent::ToolCall(tool_call) => tool_call,
972 event => {
973 panic!("Unexpected event {event:?}");
974 }
975 }
976}
977
978async fn expect_tool_call_update_fields(
979 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
980) -> acp::ToolCallUpdate {
981 let event = events
982 .next()
983 .await
984 .expect("no tool call authorization event received")
985 .unwrap();
986 match event {
987 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
988 event => {
989 panic!("Unexpected event {event:?}");
990 }
991 }
992}
993
994async fn next_tool_call_authorization(
995 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
996) -> ToolCallAuthorization {
997 loop {
998 let event = events
999 .next()
1000 .await
1001 .expect("no tool call authorization event received")
1002 .unwrap();
1003 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1004 let permission_kinds = tool_call_authorization
1005 .options
1006 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1007 .map(|option| option.kind);
1008 let allow_once = tool_call_authorization
1009 .options
1010 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1011 .map(|option| option.kind);
1012
1013 assert_eq!(
1014 permission_kinds,
1015 Some(acp::PermissionOptionKind::AllowAlways)
1016 );
1017 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1018 return tool_call_authorization;
1019 }
1020 }
1021}
1022
#[test]
fn test_permission_options_terminal_with_pattern() {
    // A command with a subcommand should offer a `cargo build` pattern option.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1044
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    // When the second token is a flag, the pattern option names only the command.
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec![String::from("ls -la")]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1064
#[test]
fn test_permission_options_terminal_single_word_command() {
    // A bare command with no arguments still gets its own pattern option.
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec![String::from("whoami")]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1084
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing `src/main.rs` should offer a tool-wide choice and a choice
    // labelled with the `src/` directory.
    let context =
        ToolPermissionContext::new(EditFileTool::NAME, vec![String::from("src/main.rs")]);
    let PermissionOptions::Dropdown(entries) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let allow_labels: Vec<&str> = entries
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for edit file"));
    assert!(allow_labels.contains(&"Always for `src/`"));
}
1102
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a `docs.rs` URL should offer a tool-wide choice and a choice
    // labelled with the URL's domain.
    let context = ToolPermissionContext::new(
        FetchTool::NAME,
        vec![String::from("https://docs.rs/gpui")],
    );
    let PermissionOptions::Dropdown(entries) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let allow_labels: Vec<&str> = entries
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for fetch"));
    assert!(allow_labels.contains(&"Always for `docs.rs`"));
}
1120
#[test]
fn test_permission_options_without_pattern() {
    // For `./deploy.sh --production` only two choices are expected, and none
    // of them may be a "... commands" pattern choice.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("./deploy.sh --production")],
    );
    let PermissionOptions::Dropdown(entries) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(entries.len(), 2);

    let allow_labels: Vec<&str> = entries
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(allow_labels.contains(&"Always for terminal"));
    assert!(allow_labels.contains(&"Only this time"));
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1142
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization is expected to be a flat list holding
    // exactly a one-time allow and a one-time reject option.
    let context = ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    );
    let PermissionOptions::Flat(options) = context.build_permission_options() else {
        panic!("Expected flat permission options for symlink target authorization");
    };
    assert_eq!(options.len(), 2);

    // True when some option carries both the given id and the given kind.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1163
#[test]
fn test_permission_option_ids_for_terminal() {
    // Checks the option ids (not the labels) of both the allow and the deny
    // side of every dropdown choice for a terminal command.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    );
    let PermissionOptions::Dropdown(entries) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let mut allow_ids: Vec<String> = Vec::new();
    let mut deny_ids: Vec<String> = Vec::new();
    for entry in entries.iter() {
        allow_ids.push(entry.allow.option_id.0.to_string());
        deny_ids.push(entry.deny.option_id.0.to_string());
    }

    assert!(allow_ids.iter().any(|id| id == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id == "allow"));
    let has_allow_pattern = allow_ids
        .iter()
        .any(|id| id.starts_with("always_allow_pattern:terminal\n"));
    assert!(has_allow_pattern, "Missing allow pattern option");

    assert!(deny_ids.iter().any(|id| id == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id == "deny"));
    let has_deny_pattern = deny_ids
        .iter()
        .any(|id| id.starts_with("always_deny_pattern:terminal\n"));
    assert!(has_deny_pattern, "Missing deny pattern option");
}
1203
1204#[gpui::test]
1205#[cfg_attr(not(feature = "e2e"), ignore)]
1206async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1207 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1208
1209 // Test concurrent tool calls with different delay times
1210 let events = thread
1211 .update(cx, |thread, cx| {
1212 thread.add_tool(DelayTool, None);
1213 thread.send(
1214 UserMessageId::new(),
1215 [
1216 "Call the delay tool twice in the same message.",
1217 "Once with 100ms. Once with 300ms.",
1218 "When both timers are complete, describe the outputs.",
1219 ],
1220 cx,
1221 )
1222 })
1223 .unwrap()
1224 .collect()
1225 .await;
1226
1227 let stop_reasons = stop_events(events);
1228 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1229
1230 thread.update(cx, |thread, _cx| {
1231 let last_message = thread.last_message().unwrap();
1232 let agent_message = last_message.as_agent_message().unwrap();
1233 let text = agent_message
1234 .content
1235 .iter()
1236 .filter_map(|content| {
1237 if let AgentMessageContent::Text(text) = content {
1238 Some(text.as_str())
1239 } else {
1240 None
1241 }
1242 })
1243 .collect::<String>();
1244
1245 assert!(text.contains("Ding"));
1246 });
1247}
1248
1249#[gpui::test]
1250async fn test_profiles(cx: &mut TestAppContext) {
1251 let ThreadTest {
1252 model, thread, fs, ..
1253 } = setup(cx, TestModel::Fake).await;
1254 let fake_model = model.as_fake();
1255
1256 thread.update(cx, |thread, _cx| {
1257 thread.add_tool(DelayTool, None);
1258 thread.add_tool(EchoTool, None);
1259 thread.add_tool(InfiniteTool, None);
1260 });
1261
1262 // Override profiles and wait for settings to be loaded.
1263 fs.insert_file(
1264 paths::settings_file(),
1265 json!({
1266 "agent": {
1267 "profiles": {
1268 "test-1": {
1269 "name": "Test Profile 1",
1270 "tools": {
1271 EchoTool::NAME: true,
1272 DelayTool::NAME: true,
1273 }
1274 },
1275 "test-2": {
1276 "name": "Test Profile 2",
1277 "tools": {
1278 InfiniteTool::NAME: true,
1279 }
1280 }
1281 }
1282 }
1283 })
1284 .to_string()
1285 .into_bytes(),
1286 )
1287 .await;
1288 cx.run_until_parked();
1289
1290 // Test that test-1 profile (default) has echo and delay tools
1291 thread
1292 .update(cx, |thread, cx| {
1293 thread.set_profile(AgentProfileId("test-1".into()), cx);
1294 thread.send(UserMessageId::new(), ["test"], cx)
1295 })
1296 .unwrap();
1297 cx.run_until_parked();
1298
1299 let mut pending_completions = fake_model.pending_completions();
1300 assert_eq!(pending_completions.len(), 1);
1301 let completion = pending_completions.pop().unwrap();
1302 let tool_names: Vec<String> = completion
1303 .tools
1304 .iter()
1305 .map(|tool| tool.name.clone())
1306 .collect();
1307 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1308 fake_model.end_last_completion_stream();
1309
1310 // Switch to test-2 profile, and verify that it has only the infinite tool.
1311 thread
1312 .update(cx, |thread, cx| {
1313 thread.set_profile(AgentProfileId("test-2".into()), cx);
1314 thread.send(UserMessageId::new(), ["test2"], cx)
1315 })
1316 .unwrap();
1317 cx.run_until_parked();
1318 let mut pending_completions = fake_model.pending_completions();
1319 assert_eq!(pending_completions.len(), 1);
1320 let completion = pending_completions.pop().unwrap();
1321 let tool_names: Vec<String> = completion
1322 .tools
1323 .iter()
1324 .map(|tool| tool.name.clone())
1325 .collect();
1326 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1327}
1328
1329#[gpui::test]
1330async fn test_mcp_tools(cx: &mut TestAppContext) {
1331 let ThreadTest {
1332 model,
1333 thread,
1334 context_server_store,
1335 fs,
1336 ..
1337 } = setup(cx, TestModel::Fake).await;
1338 let fake_model = model.as_fake();
1339
1340 // Override profiles and wait for settings to be loaded.
1341 fs.insert_file(
1342 paths::settings_file(),
1343 json!({
1344 "agent": {
1345 "tool_permissions": { "default": "allow" },
1346 "profiles": {
1347 "test": {
1348 "name": "Test Profile",
1349 "enable_all_context_servers": true,
1350 "tools": {
1351 EchoTool::NAME: true,
1352 }
1353 },
1354 }
1355 }
1356 })
1357 .to_string()
1358 .into_bytes(),
1359 )
1360 .await;
1361 cx.run_until_parked();
1362 thread.update(cx, |thread, cx| {
1363 thread.set_profile(AgentProfileId("test".into()), cx)
1364 });
1365
1366 let mut mcp_tool_calls = setup_context_server(
1367 "test_server",
1368 vec![context_server::types::Tool {
1369 name: "echo".into(),
1370 description: None,
1371 input_schema: serde_json::to_value(EchoTool::input_schema(
1372 LanguageModelToolSchemaFormat::JsonSchema,
1373 ))
1374 .unwrap(),
1375 output_schema: None,
1376 annotations: None,
1377 }],
1378 &context_server_store,
1379 cx,
1380 );
1381
1382 let events = thread.update(cx, |thread, cx| {
1383 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1384 });
1385 cx.run_until_parked();
1386
1387 // Simulate the model calling the MCP tool.
1388 let completion = fake_model.pending_completions().pop().unwrap();
1389 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1390 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1391 LanguageModelToolUse {
1392 id: "tool_1".into(),
1393 name: "echo".into(),
1394 raw_input: json!({"text": "test"}).to_string(),
1395 input: json!({"text": "test"}),
1396 is_input_complete: true,
1397 thought_signature: None,
1398 },
1399 ));
1400 fake_model.end_last_completion_stream();
1401 cx.run_until_parked();
1402
1403 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1404 assert_eq!(tool_call_params.name, "echo");
1405 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1406 tool_call_response
1407 .send(context_server::types::CallToolResponse {
1408 content: vec![context_server::types::ToolResponseContent::Text {
1409 text: "test".into(),
1410 }],
1411 is_error: None,
1412 meta: None,
1413 structured_content: None,
1414 })
1415 .unwrap();
1416 cx.run_until_parked();
1417
1418 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1419 fake_model.send_last_completion_stream_text_chunk("Done!");
1420 fake_model.end_last_completion_stream();
1421 events.collect::<Vec<_>>().await;
1422
1423 // Send again after adding the echo tool, ensuring the name collision is resolved.
1424 let events = thread.update(cx, |thread, cx| {
1425 thread.add_tool(EchoTool, None);
1426 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1427 });
1428 cx.run_until_parked();
1429 let completion = fake_model.pending_completions().pop().unwrap();
1430 assert_eq!(
1431 tool_names_for_completion(&completion),
1432 vec!["echo", "test_server_echo"]
1433 );
1434 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1435 LanguageModelToolUse {
1436 id: "tool_2".into(),
1437 name: "test_server_echo".into(),
1438 raw_input: json!({"text": "mcp"}).to_string(),
1439 input: json!({"text": "mcp"}),
1440 is_input_complete: true,
1441 thought_signature: None,
1442 },
1443 ));
1444 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1445 LanguageModelToolUse {
1446 id: "tool_3".into(),
1447 name: "echo".into(),
1448 raw_input: json!({"text": "native"}).to_string(),
1449 input: json!({"text": "native"}),
1450 is_input_complete: true,
1451 thought_signature: None,
1452 },
1453 ));
1454 fake_model.end_last_completion_stream();
1455 cx.run_until_parked();
1456
1457 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1458 assert_eq!(tool_call_params.name, "echo");
1459 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1460 tool_call_response
1461 .send(context_server::types::CallToolResponse {
1462 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1463 is_error: None,
1464 meta: None,
1465 structured_content: None,
1466 })
1467 .unwrap();
1468 cx.run_until_parked();
1469
1470 // Ensure the tool results were inserted with the correct names.
1471 let completion = fake_model.pending_completions().pop().unwrap();
1472 assert_eq!(
1473 completion.messages.last().unwrap().content,
1474 vec![
1475 MessageContent::ToolResult(LanguageModelToolResult {
1476 tool_use_id: "tool_3".into(),
1477 tool_name: "echo".into(),
1478 is_error: false,
1479 content: "native".into(),
1480 output: Some("native".into()),
1481 },),
1482 MessageContent::ToolResult(LanguageModelToolResult {
1483 tool_use_id: "tool_2".into(),
1484 tool_name: "test_server_echo".into(),
1485 is_error: false,
1486 content: "mcp".into(),
1487 output: Some("mcp".into()),
1488 },),
1489 ]
1490 );
1491 fake_model.end_last_completion_stream();
1492 events.collect::<Vec<_>>().await;
1493}
1494
1495#[gpui::test]
1496async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1497 let ThreadTest {
1498 model,
1499 thread,
1500 context_server_store,
1501 fs,
1502 ..
1503 } = setup(cx, TestModel::Fake).await;
1504 let fake_model = model.as_fake();
1505
1506 // Setup settings to allow MCP tools
1507 fs.insert_file(
1508 paths::settings_file(),
1509 json!({
1510 "agent": {
1511 "always_allow_tool_actions": true,
1512 "profiles": {
1513 "test": {
1514 "name": "Test Profile",
1515 "enable_all_context_servers": true,
1516 "tools": {}
1517 },
1518 }
1519 }
1520 })
1521 .to_string()
1522 .into_bytes(),
1523 )
1524 .await;
1525 cx.run_until_parked();
1526 thread.update(cx, |thread, cx| {
1527 thread.set_profile(AgentProfileId("test".into()), cx)
1528 });
1529
1530 // Setup a context server with a tool
1531 let mut mcp_tool_calls = setup_context_server(
1532 "github_server",
1533 vec![context_server::types::Tool {
1534 name: "issue_read".into(),
1535 description: Some("Read a GitHub issue".into()),
1536 input_schema: json!({
1537 "type": "object",
1538 "properties": {
1539 "issue_url": { "type": "string" }
1540 }
1541 }),
1542 output_schema: None,
1543 annotations: None,
1544 }],
1545 &context_server_store,
1546 cx,
1547 );
1548
1549 // Send a message and have the model call the MCP tool
1550 let events = thread.update(cx, |thread, cx| {
1551 thread
1552 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1553 .unwrap()
1554 });
1555 cx.run_until_parked();
1556
1557 // Verify the MCP tool is available to the model
1558 let completion = fake_model.pending_completions().pop().unwrap();
1559 assert_eq!(
1560 tool_names_for_completion(&completion),
1561 vec!["issue_read"],
1562 "MCP tool should be available"
1563 );
1564
1565 // Simulate the model calling the MCP tool
1566 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1567 LanguageModelToolUse {
1568 id: "tool_1".into(),
1569 name: "issue_read".into(),
1570 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1571 .to_string(),
1572 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1573 is_input_complete: true,
1574 thought_signature: None,
1575 },
1576 ));
1577 fake_model.end_last_completion_stream();
1578 cx.run_until_parked();
1579
1580 // The MCP server receives the tool call and responds with content
1581 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1582 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1583 assert_eq!(tool_call_params.name, "issue_read");
1584 tool_call_response
1585 .send(context_server::types::CallToolResponse {
1586 content: vec![context_server::types::ToolResponseContent::Text {
1587 text: expected_tool_output.into(),
1588 }],
1589 is_error: None,
1590 meta: None,
1591 structured_content: None,
1592 })
1593 .unwrap();
1594 cx.run_until_parked();
1595
1596 // After tool completes, the model continues with a new completion request
1597 // that includes the tool results. We need to respond to this.
1598 let _completion = fake_model.pending_completions().pop().unwrap();
1599 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1600 fake_model
1601 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1602 fake_model.end_last_completion_stream();
1603 events.collect::<Vec<_>>().await;
1604
1605 // Verify the tool result is stored in the thread by checking the markdown output.
1606 // The tool result is in the first assistant message (not the last one, which is
1607 // the model's response after the tool completed).
1608 thread.update(cx, |thread, _cx| {
1609 let markdown = thread.to_markdown();
1610 assert!(
1611 markdown.contains("**Tool Result**: issue_read"),
1612 "Thread should contain tool result header"
1613 );
1614 assert!(
1615 markdown.contains(expected_tool_output),
1616 "Thread should contain tool output: {}",
1617 expected_tool_output
1618 );
1619 });
1620
1621 // Simulate app restart: disconnect the MCP server.
1622 // After restart, the MCP server won't be connected yet when the thread is replayed.
1623 context_server_store.update(cx, |store, cx| {
1624 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1625 });
1626 cx.run_until_parked();
1627
1628 // Replay the thread (this is what happens when loading a saved thread)
1629 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1630
1631 let mut found_tool_call = None;
1632 let mut found_tool_call_update_with_output = None;
1633
1634 while let Some(event) = replay_events.next().await {
1635 let event = event.unwrap();
1636 match &event {
1637 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1638 found_tool_call = Some(tc.clone());
1639 }
1640 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1641 if update.tool_call_id.to_string() == "tool_1" =>
1642 {
1643 if update.fields.raw_output.is_some() {
1644 found_tool_call_update_with_output = Some(update.clone());
1645 }
1646 }
1647 _ => {}
1648 }
1649 }
1650
1651 // The tool call should be found
1652 assert!(
1653 found_tool_call.is_some(),
1654 "Tool call should be emitted during replay"
1655 );
1656
1657 assert!(
1658 found_tool_call_update_with_output.is_some(),
1659 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1660 );
1661
1662 let update = found_tool_call_update_with_output.unwrap();
1663 assert_eq!(
1664 update.fields.raw_output,
1665 Some(expected_tool_output.into()),
1666 "raw_output should contain the saved tool result"
1667 );
1668
1669 // Also verify the status is correct (completed, not failed)
1670 assert_eq!(
1671 update.fields.status,
1672 Some(acp::ToolCallStatus::Completed),
1673 "Tool call status should reflect the original completion status"
1674 );
1675}
1676
1677#[gpui::test]
1678async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1679 let ThreadTest {
1680 model,
1681 thread,
1682 context_server_store,
1683 fs,
1684 ..
1685 } = setup(cx, TestModel::Fake).await;
1686 let fake_model = model.as_fake();
1687
1688 // Set up a profile with all tools enabled
1689 fs.insert_file(
1690 paths::settings_file(),
1691 json!({
1692 "agent": {
1693 "profiles": {
1694 "test": {
1695 "name": "Test Profile",
1696 "enable_all_context_servers": true,
1697 "tools": {
1698 EchoTool::NAME: true,
1699 DelayTool::NAME: true,
1700 WordListTool::NAME: true,
1701 ToolRequiringPermission::NAME: true,
1702 InfiniteTool::NAME: true,
1703 }
1704 },
1705 }
1706 }
1707 })
1708 .to_string()
1709 .into_bytes(),
1710 )
1711 .await;
1712 cx.run_until_parked();
1713
1714 thread.update(cx, |thread, cx| {
1715 thread.set_profile(AgentProfileId("test".into()), cx);
1716 thread.add_tool(EchoTool, None);
1717 thread.add_tool(DelayTool, None);
1718 thread.add_tool(WordListTool, None);
1719 thread.add_tool(ToolRequiringPermission, None);
1720 thread.add_tool(InfiniteTool, None);
1721 });
1722
1723 // Set up multiple context servers with some overlapping tool names
1724 let _server1_calls = setup_context_server(
1725 "xxx",
1726 vec![
1727 context_server::types::Tool {
1728 name: "echo".into(), // Conflicts with native EchoTool
1729 description: None,
1730 input_schema: serde_json::to_value(EchoTool::input_schema(
1731 LanguageModelToolSchemaFormat::JsonSchema,
1732 ))
1733 .unwrap(),
1734 output_schema: None,
1735 annotations: None,
1736 },
1737 context_server::types::Tool {
1738 name: "unique_tool_1".into(),
1739 description: None,
1740 input_schema: json!({"type": "object", "properties": {}}),
1741 output_schema: None,
1742 annotations: None,
1743 },
1744 ],
1745 &context_server_store,
1746 cx,
1747 );
1748
1749 let _server2_calls = setup_context_server(
1750 "yyy",
1751 vec![
1752 context_server::types::Tool {
1753 name: "echo".into(), // Also conflicts with native EchoTool
1754 description: None,
1755 input_schema: serde_json::to_value(EchoTool::input_schema(
1756 LanguageModelToolSchemaFormat::JsonSchema,
1757 ))
1758 .unwrap(),
1759 output_schema: None,
1760 annotations: None,
1761 },
1762 context_server::types::Tool {
1763 name: "unique_tool_2".into(),
1764 description: None,
1765 input_schema: json!({"type": "object", "properties": {}}),
1766 output_schema: None,
1767 annotations: None,
1768 },
1769 context_server::types::Tool {
1770 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1771 description: None,
1772 input_schema: json!({"type": "object", "properties": {}}),
1773 output_schema: None,
1774 annotations: None,
1775 },
1776 context_server::types::Tool {
1777 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1778 description: None,
1779 input_schema: json!({"type": "object", "properties": {}}),
1780 output_schema: None,
1781 annotations: None,
1782 },
1783 ],
1784 &context_server_store,
1785 cx,
1786 );
1787 let _server3_calls = setup_context_server(
1788 "zzz",
1789 vec![
1790 context_server::types::Tool {
1791 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1792 description: None,
1793 input_schema: json!({"type": "object", "properties": {}}),
1794 output_schema: None,
1795 annotations: None,
1796 },
1797 context_server::types::Tool {
1798 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1799 description: None,
1800 input_schema: json!({"type": "object", "properties": {}}),
1801 output_schema: None,
1802 annotations: None,
1803 },
1804 context_server::types::Tool {
1805 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1806 description: None,
1807 input_schema: json!({"type": "object", "properties": {}}),
1808 output_schema: None,
1809 annotations: None,
1810 },
1811 ],
1812 &context_server_store,
1813 cx,
1814 );
1815
1816 // Server with spaces in name - tests snake_case conversion for API compatibility
1817 let _server4_calls = setup_context_server(
1818 "Azure DevOps",
1819 vec![context_server::types::Tool {
1820 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1821 description: None,
1822 input_schema: serde_json::to_value(EchoTool::input_schema(
1823 LanguageModelToolSchemaFormat::JsonSchema,
1824 ))
1825 .unwrap(),
1826 output_schema: None,
1827 annotations: None,
1828 }],
1829 &context_server_store,
1830 cx,
1831 );
1832
1833 thread
1834 .update(cx, |thread, cx| {
1835 thread.send(UserMessageId::new(), ["Go"], cx)
1836 })
1837 .unwrap();
1838 cx.run_until_parked();
1839 let completion = fake_model.pending_completions().pop().unwrap();
1840 assert_eq!(
1841 tool_names_for_completion(&completion),
1842 vec![
1843 "azure_dev_ops_echo",
1844 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1845 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1846 "delay",
1847 "echo",
1848 "infinite",
1849 "tool_requiring_permission",
1850 "unique_tool_1",
1851 "unique_tool_2",
1852 "word_list",
1853 "xxx_echo",
1854 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1855 "yyy_echo",
1856 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1857 ]
1858 );
1859}
1860
1861#[gpui::test]
1862#[cfg_attr(not(feature = "e2e"), ignore)]
1863async fn test_cancellation(cx: &mut TestAppContext) {
1864 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1865
1866 let mut events = thread
1867 .update(cx, |thread, cx| {
1868 thread.add_tool(InfiniteTool, None);
1869 thread.add_tool(EchoTool, None);
1870 thread.send(
1871 UserMessageId::new(),
1872 ["Call the echo tool, then call the infinite tool, then explain their output"],
1873 cx,
1874 )
1875 })
1876 .unwrap();
1877
1878 // Wait until both tools are called.
1879 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1880 let mut echo_id = None;
1881 let mut echo_completed = false;
1882 while let Some(event) = events.next().await {
1883 match event.unwrap() {
1884 ThreadEvent::ToolCall(tool_call) => {
1885 assert_eq!(tool_call.title, expected_tools.remove(0));
1886 if tool_call.title == "Echo" {
1887 echo_id = Some(tool_call.tool_call_id);
1888 }
1889 }
1890 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1891 acp::ToolCallUpdate {
1892 tool_call_id,
1893 fields:
1894 acp::ToolCallUpdateFields {
1895 status: Some(acp::ToolCallStatus::Completed),
1896 ..
1897 },
1898 ..
1899 },
1900 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1901 echo_completed = true;
1902 }
1903 _ => {}
1904 }
1905
1906 if expected_tools.is_empty() && echo_completed {
1907 break;
1908 }
1909 }
1910
1911 // Cancel the current send and ensure that the event stream is closed, even
1912 // if one of the tools is still running.
1913 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1914 let events = events.collect::<Vec<_>>().await;
1915 let last_event = events.last();
1916 assert!(
1917 matches!(
1918 last_event,
1919 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1920 ),
1921 "unexpected event {last_event:?}"
1922 );
1923
1924 // Ensure we can still send a new message after cancellation.
1925 let events = thread
1926 .update(cx, |thread, cx| {
1927 thread.send(
1928 UserMessageId::new(),
1929 ["Testing: reply with 'Hello' then stop."],
1930 cx,
1931 )
1932 })
1933 .unwrap()
1934 .collect::<Vec<_>>()
1935 .await;
1936 thread.update(cx, |thread, _cx| {
1937 let message = thread.last_message().unwrap();
1938 let agent_message = message.as_agent_message().unwrap();
1939 assert_eq!(
1940 agent_message.content,
1941 vec![AgentMessageContent::Text("Hello".to_string())]
1942 );
1943 });
1944 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1945}
1946
1947#[gpui::test]
1948async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1949 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1950 always_allow_tools(cx);
1951 let fake_model = model.as_fake();
1952
1953 let environment = Rc::new(cx.update(|cx| {
1954 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1955 }));
1956 let handle = environment.terminal_handle.clone().unwrap();
1957
1958 let mut events = thread
1959 .update(cx, |thread, cx| {
1960 thread.add_tool(
1961 crate::TerminalTool::new(thread.project().clone(), environment),
1962 None,
1963 );
1964 thread.send(UserMessageId::new(), ["run a command"], cx)
1965 })
1966 .unwrap();
1967
1968 cx.run_until_parked();
1969
1970 // Simulate the model calling the terminal tool
1971 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1972 LanguageModelToolUse {
1973 id: "terminal_tool_1".into(),
1974 name: TerminalTool::NAME.into(),
1975 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1976 input: json!({"command": "sleep 1000", "cd": "."}),
1977 is_input_complete: true,
1978 thought_signature: None,
1979 },
1980 ));
1981 fake_model.end_last_completion_stream();
1982
1983 // Wait for the terminal tool to start running
1984 wait_for_terminal_tool_started(&mut events, cx).await;
1985
1986 // Cancel the thread while the terminal is running
1987 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1988
1989 // Collect remaining events, driving the executor to let cancellation complete
1990 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1991
1992 // Verify the terminal was killed
1993 assert!(
1994 handle.was_killed(),
1995 "expected terminal handle to be killed on cancellation"
1996 );
1997
1998 // Verify we got a cancellation stop event
1999 assert_eq!(
2000 stop_events(remaining_events),
2001 vec![acp::StopReason::Cancelled],
2002 );
2003
2004 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
2005 thread.update(cx, |thread, _cx| {
2006 let message = thread.last_message().unwrap();
2007 let agent_message = message.as_agent_message().unwrap();
2008
2009 let tool_use = agent_message
2010 .content
2011 .iter()
2012 .find_map(|content| match content {
2013 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2014 _ => None,
2015 })
2016 .expect("expected tool use in agent message");
2017
2018 let tool_result = agent_message
2019 .tool_results
2020 .get(&tool_use.id)
2021 .expect("expected tool result");
2022
2023 let result_text = match &tool_result.content {
2024 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2025 _ => panic!("expected text content in tool result"),
2026 };
2027
2028 // "partial output" comes from FakeTerminalHandle's output field
2029 assert!(
2030 result_text.contains("partial output"),
2031 "expected tool result to contain terminal output, got: {result_text}"
2032 );
2033 // Match the actual format from process_content in terminal_tool.rs
2034 assert!(
2035 result_text.contains("The user stopped this command"),
2036 "expected tool result to indicate user stopped, got: {result_text}"
2037 );
2038 });
2039
2040 // Verify we can send a new message after cancellation
2041 verify_thread_recovery(&thread, &fake_model, cx).await;
2042}
2043
2044#[gpui::test]
2045async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2046 // This test verifies that tools which properly handle cancellation via
2047 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2048 // to cancellation and report that they were cancelled.
2049 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2050 always_allow_tools(cx);
2051 let fake_model = model.as_fake();
2052
2053 let (tool, was_cancelled) = CancellationAwareTool::new();
2054
2055 let mut events = thread
2056 .update(cx, |thread, cx| {
2057 thread.add_tool(tool, None);
2058 thread.send(
2059 UserMessageId::new(),
2060 ["call the cancellation aware tool"],
2061 cx,
2062 )
2063 })
2064 .unwrap();
2065
2066 cx.run_until_parked();
2067
2068 // Simulate the model calling the cancellation-aware tool
2069 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2070 LanguageModelToolUse {
2071 id: "cancellation_aware_1".into(),
2072 name: "cancellation_aware".into(),
2073 raw_input: r#"{}"#.into(),
2074 input: json!({}),
2075 is_input_complete: true,
2076 thought_signature: None,
2077 },
2078 ));
2079 fake_model.end_last_completion_stream();
2080
2081 cx.run_until_parked();
2082
2083 // Wait for the tool call to be reported
2084 let mut tool_started = false;
2085 let deadline = cx.executor().num_cpus() * 100;
2086 for _ in 0..deadline {
2087 cx.run_until_parked();
2088
2089 while let Some(Some(event)) = events.next().now_or_never() {
2090 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2091 if tool_call.title == "Cancellation Aware Tool" {
2092 tool_started = true;
2093 break;
2094 }
2095 }
2096 }
2097
2098 if tool_started {
2099 break;
2100 }
2101
2102 cx.background_executor
2103 .timer(Duration::from_millis(10))
2104 .await;
2105 }
2106 assert!(tool_started, "expected cancellation aware tool to start");
2107
2108 // Cancel the thread and wait for it to complete
2109 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2110
2111 // The cancel task should complete promptly because the tool handles cancellation
2112 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2113 futures::select! {
2114 _ = cancel_task.fuse() => {}
2115 _ = timeout.fuse() => {
2116 panic!("cancel task timed out - tool did not respond to cancellation");
2117 }
2118 }
2119
2120 // Verify the tool detected cancellation via its flag
2121 assert!(
2122 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2123 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2124 );
2125
2126 // Collect remaining events
2127 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2128
2129 // Verify we got a cancellation stop event
2130 assert_eq!(
2131 stop_events(remaining_events),
2132 vec![acp::StopReason::Cancelled],
2133 );
2134
2135 // Verify we can send a new message after cancellation
2136 verify_thread_recovery(&thread, &fake_model, cx).await;
2137}
2138
2139/// Helper to verify thread can recover after cancellation by sending a simple message.
2140async fn verify_thread_recovery(
2141 thread: &Entity<Thread>,
2142 fake_model: &FakeLanguageModel,
2143 cx: &mut TestAppContext,
2144) {
2145 let events = thread
2146 .update(cx, |thread, cx| {
2147 thread.send(
2148 UserMessageId::new(),
2149 ["Testing: reply with 'Hello' then stop."],
2150 cx,
2151 )
2152 })
2153 .unwrap();
2154 cx.run_until_parked();
2155 fake_model.send_last_completion_stream_text_chunk("Hello");
2156 fake_model
2157 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2158 fake_model.end_last_completion_stream();
2159
2160 let events = events.collect::<Vec<_>>().await;
2161 thread.update(cx, |thread, _cx| {
2162 let message = thread.last_message().unwrap();
2163 let agent_message = message.as_agent_message().unwrap();
2164 assert_eq!(
2165 agent_message.content,
2166 vec![AgentMessageContent::Text("Hello".to_string())]
2167 );
2168 });
2169 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2170}
2171
2172/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2173async fn wait_for_terminal_tool_started(
2174 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2175 cx: &mut TestAppContext,
2176) {
2177 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2178 for _ in 0..deadline {
2179 cx.run_until_parked();
2180
2181 while let Some(Some(event)) = events.next().now_or_never() {
2182 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2183 update,
2184 ))) = &event
2185 {
2186 if update.fields.content.as_ref().is_some_and(|content| {
2187 content
2188 .iter()
2189 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2190 }) {
2191 return;
2192 }
2193 }
2194 }
2195
2196 cx.background_executor
2197 .timer(Duration::from_millis(10))
2198 .await;
2199 }
2200 panic!("terminal tool did not start within the expected time");
2201}
2202
2203/// Collects events until a Stop event is received, driving the executor to completion.
2204async fn collect_events_until_stop(
2205 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2206 cx: &mut TestAppContext,
2207) -> Vec<Result<ThreadEvent>> {
2208 let mut collected = Vec::new();
2209 let deadline = cx.executor().num_cpus() * 200;
2210
2211 for _ in 0..deadline {
2212 cx.executor().advance_clock(Duration::from_millis(10));
2213 cx.run_until_parked();
2214
2215 while let Some(Some(event)) = events.next().now_or_never() {
2216 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2217 collected.push(event);
2218 if is_stop {
2219 return collected;
2220 }
2221 }
2222 }
2223 panic!(
2224 "did not receive Stop event within the expected time; collected {} events",
2225 collected.len()
2226 );
2227}
2228
2229#[gpui::test]
2230async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2231 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2232 always_allow_tools(cx);
2233 let fake_model = model.as_fake();
2234
2235 let environment = Rc::new(cx.update(|cx| {
2236 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2237 }));
2238 let handle = environment.terminal_handle.clone().unwrap();
2239
2240 let message_id = UserMessageId::new();
2241 let mut events = thread
2242 .update(cx, |thread, cx| {
2243 thread.add_tool(
2244 crate::TerminalTool::new(thread.project().clone(), environment),
2245 None,
2246 );
2247 thread.send(message_id.clone(), ["run a command"], cx)
2248 })
2249 .unwrap();
2250
2251 cx.run_until_parked();
2252
2253 // Simulate the model calling the terminal tool
2254 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2255 LanguageModelToolUse {
2256 id: "terminal_tool_1".into(),
2257 name: TerminalTool::NAME.into(),
2258 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2259 input: json!({"command": "sleep 1000", "cd": "."}),
2260 is_input_complete: true,
2261 thought_signature: None,
2262 },
2263 ));
2264 fake_model.end_last_completion_stream();
2265
2266 // Wait for the terminal tool to start running
2267 wait_for_terminal_tool_started(&mut events, cx).await;
2268
2269 // Truncate the thread while the terminal is running
2270 thread
2271 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2272 .unwrap();
2273
2274 // Drive the executor to let cancellation complete
2275 let _ = collect_events_until_stop(&mut events, cx).await;
2276
2277 // Verify the terminal was killed
2278 assert!(
2279 handle.was_killed(),
2280 "expected terminal handle to be killed on truncate"
2281 );
2282
2283 // Verify the thread is empty after truncation
2284 thread.update(cx, |thread, _cx| {
2285 assert_eq!(
2286 thread.to_markdown(),
2287 "",
2288 "expected thread to be empty after truncating the only message"
2289 );
2290 });
2291
2292 // Verify we can send a new message after truncation
2293 verify_thread_recovery(&thread, &fake_model, cx).await;
2294}
2295
2296#[gpui::test]
2297async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2298 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2299 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2300 always_allow_tools(cx);
2301 let fake_model = model.as_fake();
2302
2303 let environment = Rc::new(MultiTerminalEnvironment::new());
2304
2305 let mut events = thread
2306 .update(cx, |thread, cx| {
2307 thread.add_tool(
2308 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2309 None,
2310 );
2311 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2312 })
2313 .unwrap();
2314
2315 cx.run_until_parked();
2316
2317 // Simulate the model calling two terminal tools
2318 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2319 LanguageModelToolUse {
2320 id: "terminal_tool_1".into(),
2321 name: TerminalTool::NAME.into(),
2322 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2323 input: json!({"command": "sleep 1000", "cd": "."}),
2324 is_input_complete: true,
2325 thought_signature: None,
2326 },
2327 ));
2328 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2329 LanguageModelToolUse {
2330 id: "terminal_tool_2".into(),
2331 name: TerminalTool::NAME.into(),
2332 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2333 input: json!({"command": "sleep 2000", "cd": "."}),
2334 is_input_complete: true,
2335 thought_signature: None,
2336 },
2337 ));
2338 fake_model.end_last_completion_stream();
2339
2340 // Wait for both terminal tools to start by counting terminal content updates
2341 let mut terminals_started = 0;
2342 let deadline = cx.executor().num_cpus() * 100;
2343 for _ in 0..deadline {
2344 cx.run_until_parked();
2345
2346 while let Some(Some(event)) = events.next().now_or_never() {
2347 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2348 update,
2349 ))) = &event
2350 {
2351 if update.fields.content.as_ref().is_some_and(|content| {
2352 content
2353 .iter()
2354 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2355 }) {
2356 terminals_started += 1;
2357 if terminals_started >= 2 {
2358 break;
2359 }
2360 }
2361 }
2362 }
2363 if terminals_started >= 2 {
2364 break;
2365 }
2366
2367 cx.background_executor
2368 .timer(Duration::from_millis(10))
2369 .await;
2370 }
2371 assert!(
2372 terminals_started >= 2,
2373 "expected 2 terminal tools to start, got {terminals_started}"
2374 );
2375
2376 // Cancel the thread while both terminals are running
2377 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2378
2379 // Collect remaining events
2380 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2381
2382 // Verify both terminal handles were killed
2383 let handles = environment.handles();
2384 assert_eq!(
2385 handles.len(),
2386 2,
2387 "expected 2 terminal handles to be created"
2388 );
2389 assert!(
2390 handles[0].was_killed(),
2391 "expected first terminal handle to be killed on cancellation"
2392 );
2393 assert!(
2394 handles[1].was_killed(),
2395 "expected second terminal handle to be killed on cancellation"
2396 );
2397
2398 // Verify we got a cancellation stop event
2399 assert_eq!(
2400 stop_events(remaining_events),
2401 vec![acp::StopReason::Cancelled],
2402 );
2403}
2404
2405#[gpui::test]
2406async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2407 // Tests that clicking the stop button on the terminal card (as opposed to the main
2408 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2409 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2410 always_allow_tools(cx);
2411 let fake_model = model.as_fake();
2412
2413 let environment = Rc::new(cx.update(|cx| {
2414 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2415 }));
2416 let handle = environment.terminal_handle.clone().unwrap();
2417
2418 let mut events = thread
2419 .update(cx, |thread, cx| {
2420 thread.add_tool(
2421 crate::TerminalTool::new(thread.project().clone(), environment),
2422 None,
2423 );
2424 thread.send(UserMessageId::new(), ["run a command"], cx)
2425 })
2426 .unwrap();
2427
2428 cx.run_until_parked();
2429
2430 // Simulate the model calling the terminal tool
2431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2432 LanguageModelToolUse {
2433 id: "terminal_tool_1".into(),
2434 name: TerminalTool::NAME.into(),
2435 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2436 input: json!({"command": "sleep 1000", "cd": "."}),
2437 is_input_complete: true,
2438 thought_signature: None,
2439 },
2440 ));
2441 fake_model.end_last_completion_stream();
2442
2443 // Wait for the terminal tool to start running
2444 wait_for_terminal_tool_started(&mut events, cx).await;
2445
2446 // Simulate user clicking stop on the terminal card itself.
2447 // This sets the flag and signals exit (simulating what the real UI would do).
2448 handle.set_stopped_by_user(true);
2449 handle.killed.store(true, Ordering::SeqCst);
2450 handle.signal_exit();
2451
2452 // Wait for the tool to complete
2453 cx.run_until_parked();
2454
2455 // The thread continues after tool completion - simulate the model ending its turn
2456 fake_model
2457 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2458 fake_model.end_last_completion_stream();
2459
2460 // Collect remaining events
2461 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2462
2463 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2464 assert_eq!(
2465 stop_events(remaining_events),
2466 vec![acp::StopReason::EndTurn],
2467 );
2468
2469 // Verify the tool result indicates user stopped
2470 thread.update(cx, |thread, _cx| {
2471 let message = thread.last_message().unwrap();
2472 let agent_message = message.as_agent_message().unwrap();
2473
2474 let tool_use = agent_message
2475 .content
2476 .iter()
2477 .find_map(|content| match content {
2478 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2479 _ => None,
2480 })
2481 .expect("expected tool use in agent message");
2482
2483 let tool_result = agent_message
2484 .tool_results
2485 .get(&tool_use.id)
2486 .expect("expected tool result");
2487
2488 let result_text = match &tool_result.content {
2489 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2490 _ => panic!("expected text content in tool result"),
2491 };
2492
2493 assert!(
2494 result_text.contains("The user stopped this command"),
2495 "expected tool result to indicate user stopped, got: {result_text}"
2496 );
2497 });
2498}
2499
2500#[gpui::test]
2501async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2502 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2503 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2504 always_allow_tools(cx);
2505 let fake_model = model.as_fake();
2506
2507 let environment = Rc::new(cx.update(|cx| {
2508 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2509 }));
2510 let handle = environment.terminal_handle.clone().unwrap();
2511
2512 let mut events = thread
2513 .update(cx, |thread, cx| {
2514 thread.add_tool(
2515 crate::TerminalTool::new(thread.project().clone(), environment),
2516 None,
2517 );
2518 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2519 })
2520 .unwrap();
2521
2522 cx.run_until_parked();
2523
2524 // Simulate the model calling the terminal tool with a short timeout
2525 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2526 LanguageModelToolUse {
2527 id: "terminal_tool_1".into(),
2528 name: TerminalTool::NAME.into(),
2529 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2530 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2531 is_input_complete: true,
2532 thought_signature: None,
2533 },
2534 ));
2535 fake_model.end_last_completion_stream();
2536
2537 // Wait for the terminal tool to start running
2538 wait_for_terminal_tool_started(&mut events, cx).await;
2539
2540 // Advance clock past the timeout
2541 cx.executor().advance_clock(Duration::from_millis(200));
2542 cx.run_until_parked();
2543
2544 // The thread continues after tool completion - simulate the model ending its turn
2545 fake_model
2546 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2547 fake_model.end_last_completion_stream();
2548
2549 // Collect remaining events
2550 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2551
2552 // Verify the terminal was killed due to timeout
2553 assert!(
2554 handle.was_killed(),
2555 "expected terminal handle to be killed on timeout"
2556 );
2557
2558 // Verify we got an EndTurn (the tool completed, just with timeout)
2559 assert_eq!(
2560 stop_events(remaining_events),
2561 vec![acp::StopReason::EndTurn],
2562 );
2563
2564 // Verify the tool result indicates timeout, not user stopped
2565 thread.update(cx, |thread, _cx| {
2566 let message = thread.last_message().unwrap();
2567 let agent_message = message.as_agent_message().unwrap();
2568
2569 let tool_use = agent_message
2570 .content
2571 .iter()
2572 .find_map(|content| match content {
2573 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2574 _ => None,
2575 })
2576 .expect("expected tool use in agent message");
2577
2578 let tool_result = agent_message
2579 .tool_results
2580 .get(&tool_use.id)
2581 .expect("expected tool result");
2582
2583 let result_text = match &tool_result.content {
2584 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2585 _ => panic!("expected text content in tool result"),
2586 };
2587
2588 assert!(
2589 result_text.contains("timed out"),
2590 "expected tool result to indicate timeout, got: {result_text}"
2591 );
2592 assert!(
2593 !result_text.contains("The user stopped"),
2594 "tool result should not mention user stopped when it timed out, got: {result_text}"
2595 );
2596 });
2597}
2598
2599#[gpui::test]
2600async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2601 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2602 let fake_model = model.as_fake();
2603
2604 let events_1 = thread
2605 .update(cx, |thread, cx| {
2606 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2607 })
2608 .unwrap();
2609 cx.run_until_parked();
2610 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2611 cx.run_until_parked();
2612
2613 let events_2 = thread
2614 .update(cx, |thread, cx| {
2615 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2616 })
2617 .unwrap();
2618 cx.run_until_parked();
2619 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2620 fake_model
2621 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2622 fake_model.end_last_completion_stream();
2623
2624 let events_1 = events_1.collect::<Vec<_>>().await;
2625 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2626 let events_2 = events_2.collect::<Vec<_>>().await;
2627 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2628}
2629
2630#[gpui::test]
2631async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2632 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2633 let fake_model = model.as_fake();
2634
2635 let events_1 = thread
2636 .update(cx, |thread, cx| {
2637 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2638 })
2639 .unwrap();
2640 cx.run_until_parked();
2641 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2642 fake_model
2643 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2644 fake_model.end_last_completion_stream();
2645 let events_1 = events_1.collect::<Vec<_>>().await;
2646
2647 let events_2 = thread
2648 .update(cx, |thread, cx| {
2649 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2650 })
2651 .unwrap();
2652 cx.run_until_parked();
2653 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2654 fake_model
2655 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2656 fake_model.end_last_completion_stream();
2657 let events_2 = events_2.collect::<Vec<_>>().await;
2658
2659 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2660 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2661}
2662
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // Checks how the thread reacts when the model stops with `StopReason::Refusal`
    // after it has already streamed part of a reply.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model answers, only the user message is in the thread.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // The streamed chunk shows up as an assistant message.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should drop the refused turn.
    // The thread is expected to be empty afterwards, so the user message that
    // triggered the refusal is removed together with the partial assistant reply.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
2711
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) user message should empty the thread and
    // clear its token usage, and the thread should accept a new message afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so the thread can be truncated at this message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Only the user message exists so far, and no usage has been reported yet.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update; the completion stream is left open.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // `used_tokens` is expected to be input + output tokens.
        // NOTE(review): `max_tokens` is presumably the fake model's context size; confirm.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncate at the first message: both the messages and the usage should be gone.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // Checked before the executor runs: the new user message is expected to be
    // visible right after `send` returns.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // The new turn's reply and usage are reported; nothing from the truncated turn remains.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2833
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread to exactly
    // the state it had after the first turn, including the reported token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First turn: reply, usage update, and end of the completion stream.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Shared check for "only the first turn exists"; run once now and once more
    // after the truncation at the end of the test.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            // `used_tokens` is expected to be input + output tokens.
            // NOTE(review): `max_tokens` is presumably the fake model's context size; confirm.
            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                    max_output_tokens: None,
                    input_tokens: 32_000,
                    output_tokens: 16_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second turn; its id is kept so the thread can be truncated at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both turns are present and the usage reflects the second turn's update.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });

    // Truncate at the second message; the first-turn state, usage included, must come back.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
2948
2949#[gpui::test]
2950async fn test_title_generation(cx: &mut TestAppContext) {
2951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2952 let fake_model = model.as_fake();
2953
2954 let summary_model = Arc::new(FakeLanguageModel::default());
2955 thread.update(cx, |thread, cx| {
2956 thread.set_summarization_model(Some(summary_model.clone()), cx)
2957 });
2958
2959 let send = thread
2960 .update(cx, |thread, cx| {
2961 thread.send(UserMessageId::new(), ["Hello"], cx)
2962 })
2963 .unwrap();
2964 cx.run_until_parked();
2965
2966 fake_model.send_last_completion_stream_text_chunk("Hey!");
2967 fake_model.end_last_completion_stream();
2968 cx.run_until_parked();
2969 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2970
2971 // Ensure the summary model has been invoked to generate a title.
2972 summary_model.send_last_completion_stream_text_chunk("Hello ");
2973 summary_model.send_last_completion_stream_text_chunk("world\nG");
2974 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2975 summary_model.end_last_completion_stream();
2976 send.collect::<Vec<_>>().await;
2977 cx.run_until_parked();
2978 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2979
2980 // Send another message, ensuring no title is generated this time.
2981 let send = thread
2982 .update(cx, |thread, cx| {
2983 thread.send(UserMessageId::new(), ["Hello again"], cx)
2984 })
2985 .unwrap();
2986 cx.run_until_parked();
2987 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2988 fake_model.end_last_completion_stream();
2989 cx.run_until_parked();
2990 assert_eq!(summary_model.pending_completions(), Vec::new());
2991 send.collect::<Vec<_>>().await;
2992 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2993}
2994
2995#[gpui::test]
2996async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2997 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2998 let fake_model = model.as_fake();
2999
3000 let _events = thread
3001 .update(cx, |thread, cx| {
3002 thread.add_tool(ToolRequiringPermission, None);
3003 thread.add_tool(EchoTool, None);
3004 thread.send(UserMessageId::new(), ["Hey!"], cx)
3005 })
3006 .unwrap();
3007 cx.run_until_parked();
3008
3009 let permission_tool_use = LanguageModelToolUse {
3010 id: "tool_id_1".into(),
3011 name: ToolRequiringPermission::NAME.into(),
3012 raw_input: "{}".into(),
3013 input: json!({}),
3014 is_input_complete: true,
3015 thought_signature: None,
3016 };
3017 let echo_tool_use = LanguageModelToolUse {
3018 id: "tool_id_2".into(),
3019 name: EchoTool::NAME.into(),
3020 raw_input: json!({"text": "test"}).to_string(),
3021 input: json!({"text": "test"}),
3022 is_input_complete: true,
3023 thought_signature: None,
3024 };
3025 fake_model.send_last_completion_stream_text_chunk("Hi!");
3026 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3027 permission_tool_use,
3028 ));
3029 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3030 echo_tool_use.clone(),
3031 ));
3032 fake_model.end_last_completion_stream();
3033 cx.run_until_parked();
3034
3035 // Ensure pending tools are skipped when building a request.
3036 let request = thread
3037 .read_with(cx, |thread, cx| {
3038 thread.build_completion_request(CompletionIntent::EditFile, cx)
3039 })
3040 .unwrap();
3041 assert_eq!(
3042 request.messages[1..],
3043 vec![
3044 LanguageModelRequestMessage {
3045 role: Role::User,
3046 content: vec!["Hey!".into()],
3047 cache: true,
3048 reasoning_details: None,
3049 },
3050 LanguageModelRequestMessage {
3051 role: Role::Assistant,
3052 content: vec![
3053 MessageContent::Text("Hi!".into()),
3054 MessageContent::ToolUse(echo_tool_use.clone())
3055 ],
3056 cache: false,
3057 reasoning_details: None,
3058 },
3059 LanguageModelRequestMessage {
3060 role: Role::User,
3061 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3062 tool_use_id: echo_tool_use.id.clone(),
3063 tool_name: echo_tool_use.name,
3064 is_error: false,
3065 content: "test".into(),
3066 output: Some("test".into())
3067 })],
3068 cache: false,
3069 reasoning_details: None,
3070 },
3071 ],
3072 );
3073}
3074
// Exercises `NativeAgentConnection` end to end: creates a session, checks the
// model selector, streams a prompt response, cancels the prompt, closes the
// session, and verifies that prompting the closed session reports
// "Session not found".
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_session
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let thread_store = cx.new(|cx| ThreadStore::new(cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        thread_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_session
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_session(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The first model in the "Fake" group is expected to have the id "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    // Resolve the selected model id to the concrete model held by the agent so
    // the test can drive its completion stream.
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Send a prompt and stream a single text chunk back from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    // The in-flight request must still resolve with `Ok` after cancellation.
    request.await.expect("prompt should fail gracefully");

    // Explicitly close the session and drop the ACP thread.
    cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
        .await
        .unwrap();
    drop(acp_thread);
    // Prompting the closed session must now fail.
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
3205
3206#[gpui::test]
3207async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3208 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3209 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3210 let fake_model = model.as_fake();
3211
3212 let mut events = thread
3213 .update(cx, |thread, cx| {
3214 thread.send(UserMessageId::new(), ["Echo something"], cx)
3215 })
3216 .unwrap();
3217 cx.run_until_parked();
3218
3219 // Simulate streaming partial input.
3220 let input = json!({});
3221 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3222 LanguageModelToolUse {
3223 id: "1".into(),
3224 name: EchoTool::NAME.into(),
3225 raw_input: input.to_string(),
3226 input,
3227 is_input_complete: false,
3228 thought_signature: None,
3229 },
3230 ));
3231
3232 // Input streaming completed
3233 let input = json!({ "text": "Hello!" });
3234 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3235 LanguageModelToolUse {
3236 id: "1".into(),
3237 name: "echo".into(),
3238 raw_input: input.to_string(),
3239 input,
3240 is_input_complete: true,
3241 thought_signature: None,
3242 },
3243 ));
3244 fake_model.end_last_completion_stream();
3245 cx.run_until_parked();
3246
3247 let tool_call = expect_tool_call(&mut events).await;
3248 assert_eq!(
3249 tool_call,
3250 acp::ToolCall::new("1", "Echo")
3251 .raw_input(json!({}))
3252 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3253 );
3254 let update = expect_tool_call_update_fields(&mut events).await;
3255 assert_eq!(
3256 update,
3257 acp::ToolCallUpdate::new(
3258 "1",
3259 acp::ToolCallUpdateFields::new()
3260 .title("Echo")
3261 .kind(acp::ToolKind::Other)
3262 .raw_input(json!({ "text": "Hello!"}))
3263 )
3264 );
3265 let update = expect_tool_call_update_fields(&mut events).await;
3266 assert_eq!(
3267 update,
3268 acp::ToolCallUpdate::new(
3269 "1",
3270 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3271 )
3272 );
3273 let update = expect_tool_call_update_fields(&mut events).await;
3274 assert_eq!(
3275 update,
3276 acp::ToolCallUpdate::new(
3277 "1",
3278 acp::ToolCallUpdateFields::new()
3279 .status(acp::ToolCallStatus::Completed)
3280 .raw_output("Hello!")
3281 )
3282 );
3283}
3284
3285#[gpui::test]
3286async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3287 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3288 let fake_model = model.as_fake();
3289
3290 let mut events = thread
3291 .update(cx, |thread, cx| {
3292 thread.send(UserMessageId::new(), ["Hello!"], cx)
3293 })
3294 .unwrap();
3295 cx.run_until_parked();
3296
3297 fake_model.send_last_completion_stream_text_chunk("Hey!");
3298 fake_model.end_last_completion_stream();
3299
3300 let mut retry_events = Vec::new();
3301 while let Some(Ok(event)) = events.next().await {
3302 match event {
3303 ThreadEvent::Retry(retry_status) => {
3304 retry_events.push(retry_status);
3305 }
3306 ThreadEvent::Stop(..) => break,
3307 _ => {}
3308 }
3309 }
3310
3311 assert_eq!(retry_events.len(), 0);
3312 thread.read_with(cx, |thread, _cx| {
3313 assert_eq!(
3314 thread.to_markdown(),
3315 indoc! {"
3316 ## User
3317
3318 Hello!
3319
3320 ## Assistant
3321
3322 Hey!
3323 "}
3324 )
3325 });
3326}
3327
// Verifies that a `ServerOverloaded` error in the middle of a response
// triggers exactly one retry, and that the thread's markdown shows the partial
// response, a `[resume]` marker, and the response from the retried request.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt: emit a partial chunk, then fail with a retryable error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the `retry_after` duration used above.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Second attempt: complete successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Collect retry notifications until the thread reports a stop (or the
    // stream ends or yields an error).
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3391
// Verifies that when a completion fails after emitting a tool use, the retried
// request already contains both the tool use and its result, and that the
// thread ends with the text from the retried completion.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt: the model requests the echo tool, then the stream fails
    // with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the `retry_after` duration used above, then
    // inspect the pending (retried) request. The first message is skipped by
    // the `[1..]` slice and is not checked here.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Second attempt: complete with plain text and drain all remaining events.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3471
3472#[gpui::test]
3473async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3474 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3475 let fake_model = model.as_fake();
3476
3477 let mut events = thread
3478 .update(cx, |thread, cx| {
3479 thread.send(UserMessageId::new(), ["Hello!"], cx)
3480 })
3481 .unwrap();
3482 cx.run_until_parked();
3483
3484 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3485 fake_model.send_last_completion_stream_error(
3486 LanguageModelCompletionError::ServerOverloaded {
3487 provider: LanguageModelProviderName::new("Anthropic"),
3488 retry_after: Some(Duration::from_secs(3)),
3489 },
3490 );
3491 fake_model.end_last_completion_stream();
3492 cx.executor().advance_clock(Duration::from_secs(3));
3493 cx.run_until_parked();
3494 }
3495
3496 let mut errors = Vec::new();
3497 let mut retry_events = Vec::new();
3498 while let Some(event) = events.next().await {
3499 match event {
3500 Ok(ThreadEvent::Retry(retry_status)) => {
3501 retry_events.push(retry_status);
3502 }
3503 Ok(ThreadEvent::Stop(..)) => break,
3504 Err(error) => errors.push(error),
3505 _ => {}
3506 }
3507 }
3508
3509 assert_eq!(
3510 retry_events.len(),
3511 crate::thread::MAX_RETRY_ATTEMPTS as usize
3512 );
3513 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3514 assert_eq!(retry_events[i].attempt, i + 1);
3515 }
3516 assert_eq!(errors.len(), 1);
3517 let error = errors[0]
3518 .downcast_ref::<LanguageModelCompletionError>()
3519 .unwrap();
3520 assert!(matches!(
3521 error,
3522 LanguageModelCompletionError::ServerOverloaded { .. }
3523 ));
3524}
3525
3526/// Filters out the stop events for asserting against in tests
3527fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3528 result_events
3529 .into_iter()
3530 .filter_map(|event| match event.unwrap() {
3531 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3532 _ => None,
3533 })
3534 .collect()
3535}
3536
/// Bundle of objects produced by [`setup`] for a single thread test.
struct ThreadTest {
    /// The language model handed to the thread when it was created.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// The context server store taken from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem holding the settings file and the test project.
    fs: Arc<FakeFs>,
}
3544
/// Selects which language model [`setup`] wires into the test thread.
enum TestModel {
    /// A model looked up in the global `LanguageModelRegistry` by id; `setup`
    /// authenticates its provider before returning.
    Sonnet4,
    /// A default-constructed `FakeLanguageModel` driven directly by the test.
    Fake,
}
3549
impl TestModel {
    /// Returns the registry id used to look this model up.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) for [`TestModel::Fake`]; `setup` builds the
    /// fake model directly and only calls this for non-fake models.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
3558
/// Builds a [`ThreadTest`] fixture: writes a settings file with a test agent
/// profile to a fake filesystem, initializes settings (and, for
/// `TestModel::Sonnet4`, the HTTP client and language model providers),
/// creates a test project rooted at `/test`, resolves the requested model, and
/// constructs a `Thread` using it.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file that defines "test-profile" as the default agent
    // profile with the tools listed below enabled.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            // The fake model needs no client or provider initialization.
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Feed the settings file written above into the global settings store.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model: the fake model is ready immediately, while a registry
    // model is returned only after its provider has authenticated.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3661
/// Initializes `env_logger`, but only when the `RUST_LOG` environment variable
/// can be read (`std::env::var` succeeds only if it is set and valid Unicode).
///
/// Registered through `#[ctor::ctor]` — presumably so it runs once at load
/// time, before any test executes; confirm against the `ctor` crate docs.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}
3669
3670fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3671 let fs = fs.clone();
3672 cx.spawn({
3673 async move |cx| {
3674 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3675 cx.background_executor(),
3676 fs,
3677 paths::settings_file().clone(),
3678 );
3679 let _watcher_task = watcher_task;
3680
3681 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3682 cx.update(|cx| {
3683 SettingsStore::update_global(cx, |settings, cx| {
3684 settings.set_user_settings(&new_settings_content, cx)
3685 })
3686 })
3687 .ok();
3688 }
3689 }
3690 })
3691 .detach();
3692}
3693
3694fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3695 completion
3696 .tools
3697 .iter()
3698 .map(|tool| tool.name.clone())
3699 .collect()
3700}
3701
/// Registers and starts a fake context server for tests.
///
/// Adds an enabled stdio context server entry called `name` to the global
/// `ProjectSettings`, then starts a `ContextServer` on `context_server_store`
/// backed by a fake transport that answers `Initialize`, `ListTools` (with
/// `tools`) and `CallTool` requests.
///
/// Returns a receiver that yields, for each `CallTool` request, the request
/// parameters together with a oneshot sender the test must use to supply the
/// response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Register the server in the project settings.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Initialize: advertise tool support with list-changed notifications.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // ListTools: always answer with the caller-provided tool list.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // CallTool: forward the call to the test and wait for its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3781
// Verifies `Thread::tokens_before_message` across three consecutive messages:
// the first message reports `None`, and each later message reports the input
// token count from the request that preceded it.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3888
// Verifies that after truncating the thread at the second message,
// `Thread::tokens_before_message` returns `None` both for the removed message
// and for the remaining first message.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3959
3960#[gpui::test]
3961async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3962 init_test(cx);
3963
3964 let fs = FakeFs::new(cx.executor());
3965 fs.insert_tree("/root", json!({})).await;
3966 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3967
3968 // Test 1: Deny rule blocks command
3969 {
3970 let environment = Rc::new(cx.update(|cx| {
3971 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3972 }));
3973
3974 cx.update(|cx| {
3975 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3976 settings.tool_permissions.tools.insert(
3977 TerminalTool::NAME.into(),
3978 agent_settings::ToolRules {
3979 default: Some(settings::ToolPermissionMode::Confirm),
3980 always_allow: vec![],
3981 always_deny: vec![
3982 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3983 ],
3984 always_confirm: vec![],
3985 invalid_patterns: vec![],
3986 },
3987 );
3988 agent_settings::AgentSettings::override_global(settings, cx);
3989 });
3990
3991 #[allow(clippy::arc_with_non_send_sync)]
3992 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3993 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3994
3995 let task = cx.update(|cx| {
3996 tool.run(
3997 crate::TerminalToolInput {
3998 command: "rm -rf /".to_string(),
3999 cd: ".".to_string(),
4000 timeout_ms: None,
4001 },
4002 event_stream,
4003 cx,
4004 )
4005 });
4006
4007 let result = task.await;
4008 assert!(
4009 result.is_err(),
4010 "expected command to be blocked by deny rule"
4011 );
4012 let err_msg = result.unwrap_err().to_string().to_lowercase();
4013 assert!(
4014 err_msg.contains("blocked"),
4015 "error should mention the command was blocked"
4016 );
4017 }
4018
4019 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4020 {
4021 let environment = Rc::new(cx.update(|cx| {
4022 FakeThreadEnvironment::default()
4023 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4024 }));
4025
4026 cx.update(|cx| {
4027 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4028 settings.tool_permissions.tools.insert(
4029 TerminalTool::NAME.into(),
4030 agent_settings::ToolRules {
4031 default: Some(settings::ToolPermissionMode::Deny),
4032 always_allow: vec![
4033 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4034 ],
4035 always_deny: vec![],
4036 always_confirm: vec![],
4037 invalid_patterns: vec![],
4038 },
4039 );
4040 agent_settings::AgentSettings::override_global(settings, cx);
4041 });
4042
4043 #[allow(clippy::arc_with_non_send_sync)]
4044 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4045 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4046
4047 let task = cx.update(|cx| {
4048 tool.run(
4049 crate::TerminalToolInput {
4050 command: "echo hello".to_string(),
4051 cd: ".".to_string(),
4052 timeout_ms: None,
4053 },
4054 event_stream,
4055 cx,
4056 )
4057 });
4058
4059 let update = rx.expect_update_fields().await;
4060 assert!(
4061 update.content.iter().any(|blocks| {
4062 blocks
4063 .iter()
4064 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4065 }),
4066 "expected terminal content (allow rule should skip confirmation and override default deny)"
4067 );
4068
4069 let result = task.await;
4070 assert!(
4071 result.is_ok(),
4072 "expected command to succeed without confirmation"
4073 );
4074 }
4075
4076 // Test 3: global default: allow does NOT override always_confirm patterns
4077 {
4078 let environment = Rc::new(cx.update(|cx| {
4079 FakeThreadEnvironment::default()
4080 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4081 }));
4082
4083 cx.update(|cx| {
4084 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4085 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4086 settings.tool_permissions.tools.insert(
4087 TerminalTool::NAME.into(),
4088 agent_settings::ToolRules {
4089 default: Some(settings::ToolPermissionMode::Allow),
4090 always_allow: vec![],
4091 always_deny: vec![],
4092 always_confirm: vec![
4093 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4094 ],
4095 invalid_patterns: vec![],
4096 },
4097 );
4098 agent_settings::AgentSettings::override_global(settings, cx);
4099 });
4100
4101 #[allow(clippy::arc_with_non_send_sync)]
4102 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4103 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4104
4105 let _task = cx.update(|cx| {
4106 tool.run(
4107 crate::TerminalToolInput {
4108 command: "sudo rm file".to_string(),
4109 cd: ".".to_string(),
4110 timeout_ms: None,
4111 },
4112 event_stream,
4113 cx,
4114 )
4115 });
4116
4117 // With global default: allow, confirm patterns are still respected
4118 // The expect_authorization() call will panic if no authorization is requested,
4119 // which validates that the confirm pattern still triggers confirmation
4120 let _auth = rx.expect_authorization().await;
4121
4122 drop(_task);
4123 }
4124
4125 // Test 4: tool-specific default: deny is respected even with global default: allow
4126 {
4127 let environment = Rc::new(cx.update(|cx| {
4128 FakeThreadEnvironment::default()
4129 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4130 }));
4131
4132 cx.update(|cx| {
4133 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4134 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4135 settings.tool_permissions.tools.insert(
4136 TerminalTool::NAME.into(),
4137 agent_settings::ToolRules {
4138 default: Some(settings::ToolPermissionMode::Deny),
4139 always_allow: vec![],
4140 always_deny: vec![],
4141 always_confirm: vec![],
4142 invalid_patterns: vec![],
4143 },
4144 );
4145 agent_settings::AgentSettings::override_global(settings, cx);
4146 });
4147
4148 #[allow(clippy::arc_with_non_send_sync)]
4149 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4150 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4151
4152 let task = cx.update(|cx| {
4153 tool.run(
4154 crate::TerminalToolInput {
4155 command: "echo hello".to_string(),
4156 cd: ".".to_string(),
4157 timeout_ms: None,
4158 },
4159 event_stream,
4160 cx,
4161 )
4162 });
4163
4164 // tool-specific default: deny is respected even with global default: allow
4165 let result = task.await;
4166 assert!(
4167 result.is_err(),
4168 "expected command to be blocked by tool-specific deny default"
4169 );
4170 let err_msg = result.unwrap_err().to_string().to_lowercase();
4171 assert!(
4172 err_msg.contains("disabled"),
4173 "error should mention the tool is disabled, got: {err_msg}"
4174 );
4175 }
4176}
4177
4178#[gpui::test]
4179async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4180 init_test(cx);
4181 cx.update(|cx| {
4182 LanguageModelRegistry::test(cx);
4183 });
4184 cx.update(|cx| {
4185 cx.update_flags(true, vec!["subagents".to_string()]);
4186 });
4187
4188 let fs = FakeFs::new(cx.executor());
4189 fs.insert_tree(
4190 "/",
4191 json!({
4192 "a": {
4193 "b.md": "Lorem"
4194 }
4195 }),
4196 )
4197 .await;
4198 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4199 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4200 let agent = NativeAgent::new(
4201 project.clone(),
4202 thread_store.clone(),
4203 Templates::new(),
4204 None,
4205 fs.clone(),
4206 &mut cx.to_async(),
4207 )
4208 .await
4209 .unwrap();
4210 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4211
4212 let acp_thread = cx
4213 .update(|cx| {
4214 connection
4215 .clone()
4216 .new_session(project.clone(), Path::new(""), cx)
4217 })
4218 .await
4219 .unwrap();
4220 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4221 let thread = agent.read_with(cx, |agent, _| {
4222 agent.sessions.get(&session_id).unwrap().thread.clone()
4223 });
4224 let model = Arc::new(FakeLanguageModel::default());
4225
4226 // Ensure empty threads are not saved, even if they get mutated.
4227 thread.update(cx, |thread, cx| {
4228 thread.set_model(model.clone(), cx);
4229 });
4230 cx.run_until_parked();
4231
4232 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4233 cx.run_until_parked();
4234 model.send_last_completion_stream_text_chunk("spawning subagent");
4235 let subagent_tool_input = SubagentToolInput {
4236 label: "label".to_string(),
4237 task_prompt: "subagent task prompt".to_string(),
4238 summary_prompt: "subagent summary prompt".to_string(),
4239 timeout_ms: None,
4240 allowed_tools: None,
4241 };
4242 let subagent_tool_use = LanguageModelToolUse {
4243 id: "subagent_1".into(),
4244 name: SubagentTool::NAME.into(),
4245 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4246 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4247 is_input_complete: true,
4248 thought_signature: None,
4249 };
4250 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4251 subagent_tool_use,
4252 ));
4253 model.end_last_completion_stream();
4254
4255 cx.run_until_parked();
4256
4257 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4258 thread
4259 .running_subagent_ids(cx)
4260 .get(0)
4261 .expect("subagent thread should be running")
4262 .clone()
4263 });
4264
4265 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4266 agent
4267 .sessions
4268 .get(&subagent_session_id)
4269 .expect("subagent session should exist")
4270 .acp_thread
4271 .clone()
4272 });
4273
4274 model.send_last_completion_stream_text_chunk("subagent task response");
4275 model.end_last_completion_stream();
4276
4277 cx.run_until_parked();
4278
4279 model.send_last_completion_stream_text_chunk("subagent summary response");
4280 model.end_last_completion_stream();
4281
4282 cx.run_until_parked();
4283
4284 assert_eq!(
4285 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4286 indoc! {"
4287 ## User
4288
4289 subagent task prompt
4290
4291 ## Assistant
4292
4293 subagent task response
4294
4295 ## User
4296
4297 subagent summary prompt
4298
4299 ## Assistant
4300
4301 subagent summary response
4302
4303 "}
4304 );
4305
4306 model.send_last_completion_stream_text_chunk("Response");
4307 model.end_last_completion_stream();
4308
4309 send.await.unwrap();
4310
4311 assert_eq!(
4312 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4313 format!(
4314 indoc! {r#"
4315 ## User
4316
4317 Prompt
4318
4319 ## Assistant
4320
4321 spawning subagent
4322
4323 **Tool Call: label**
4324 Status: Completed
4325
4326 ```json
4327 {{
4328 "subagent_session_id": "{}",
4329 "summary": "subagent summary response\n"
4330 }}
4331 ```
4332
4333 ## Assistant
4334
4335 Response
4336
4337 "#},
4338 subagent_session_id
4339 )
4340 );
4341}
4342
4343#[gpui::test]
4344async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4345 init_test(cx);
4346 cx.update(|cx| {
4347 LanguageModelRegistry::test(cx);
4348 });
4349 cx.update(|cx| {
4350 cx.update_flags(true, vec!["subagents".to_string()]);
4351 });
4352
4353 let fs = FakeFs::new(cx.executor());
4354 fs.insert_tree(
4355 "/",
4356 json!({
4357 "a": {
4358 "b.md": "Lorem"
4359 }
4360 }),
4361 )
4362 .await;
4363 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4364 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4365 let agent = NativeAgent::new(
4366 project.clone(),
4367 thread_store.clone(),
4368 Templates::new(),
4369 None,
4370 fs.clone(),
4371 &mut cx.to_async(),
4372 )
4373 .await
4374 .unwrap();
4375 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4376
4377 let acp_thread = cx
4378 .update(|cx| {
4379 connection
4380 .clone()
4381 .new_session(project.clone(), Path::new(""), cx)
4382 })
4383 .await
4384 .unwrap();
4385 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4386 let thread = agent.read_with(cx, |agent, _| {
4387 agent.sessions.get(&session_id).unwrap().thread.clone()
4388 });
4389 let model = Arc::new(FakeLanguageModel::default());
4390
4391 // Ensure empty threads are not saved, even if they get mutated.
4392 thread.update(cx, |thread, cx| {
4393 thread.set_model(model.clone(), cx);
4394 });
4395 cx.run_until_parked();
4396
4397 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4398 cx.run_until_parked();
4399 model.send_last_completion_stream_text_chunk("spawning subagent");
4400 let subagent_tool_input = SubagentToolInput {
4401 label: "label".to_string(),
4402 task_prompt: "subagent task prompt".to_string(),
4403 summary_prompt: "subagent summary prompt".to_string(),
4404 timeout_ms: None,
4405 allowed_tools: None,
4406 };
4407 let subagent_tool_use = LanguageModelToolUse {
4408 id: "subagent_1".into(),
4409 name: SubagentTool::NAME.into(),
4410 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4411 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4412 is_input_complete: true,
4413 thought_signature: None,
4414 };
4415 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4416 subagent_tool_use,
4417 ));
4418 model.end_last_completion_stream();
4419
4420 cx.run_until_parked();
4421
4422 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4423 thread
4424 .running_subagent_ids(cx)
4425 .get(0)
4426 .expect("subagent thread should be running")
4427 .clone()
4428 });
4429 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4430 agent
4431 .sessions
4432 .get(&subagent_session_id)
4433 .expect("subagent session should exist")
4434 .acp_thread
4435 .clone()
4436 });
4437
4438 // model.send_last_completion_stream_text_chunk("subagent task response");
4439 // model.end_last_completion_stream();
4440
4441 // cx.run_until_parked();
4442
4443 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4444
4445 cx.run_until_parked();
4446
4447 send.await.unwrap();
4448
4449 acp_thread.read_with(cx, |thread, cx| {
4450 assert_eq!(thread.status(), ThreadStatus::Idle);
4451 assert_eq!(
4452 thread.to_markdown(cx),
4453 indoc! {"
4454 ## User
4455
4456 Prompt
4457
4458 ## Assistant
4459
4460 spawning subagent
4461
4462 **Tool Call: label**
4463 Status: Canceled
4464
4465 "}
4466 );
4467 });
4468 subagent_acp_thread.read_with(cx, |thread, cx| {
4469 assert_eq!(thread.status(), ThreadStatus::Idle);
4470 assert_eq!(
4471 thread.to_markdown(cx),
4472 indoc! {"
4473 ## User
4474
4475 subagent task prompt
4476
4477 "}
4478 );
4479 });
4480}
4481
4482#[gpui::test]
4483async fn test_subagent_tool_call_cancellation_during_summary_prompt(cx: &mut TestAppContext) {
4484 init_test(cx);
4485 cx.update(|cx| {
4486 LanguageModelRegistry::test(cx);
4487 });
4488 cx.update(|cx| {
4489 cx.update_flags(true, vec!["subagents".to_string()]);
4490 });
4491
4492 let fs = FakeFs::new(cx.executor());
4493 fs.insert_tree(
4494 "/",
4495 json!({
4496 "a": {
4497 "b.md": "Lorem"
4498 }
4499 }),
4500 )
4501 .await;
4502 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4503 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4504 let agent = NativeAgent::new(
4505 project.clone(),
4506 thread_store.clone(),
4507 Templates::new(),
4508 None,
4509 fs.clone(),
4510 &mut cx.to_async(),
4511 )
4512 .await
4513 .unwrap();
4514 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4515
4516 let acp_thread = cx
4517 .update(|cx| {
4518 connection
4519 .clone()
4520 .new_session(project.clone(), Path::new(""), cx)
4521 })
4522 .await
4523 .unwrap();
4524 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4525 let thread = agent.read_with(cx, |agent, _| {
4526 agent.sessions.get(&session_id).unwrap().thread.clone()
4527 });
4528 let model = Arc::new(FakeLanguageModel::default());
4529
4530 // Ensure empty threads are not saved, even if they get mutated.
4531 thread.update(cx, |thread, cx| {
4532 thread.set_model(model.clone(), cx);
4533 });
4534 cx.run_until_parked();
4535
4536 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4537 cx.run_until_parked();
4538 model.send_last_completion_stream_text_chunk("spawning subagent");
4539 let subagent_tool_input = SubagentToolInput {
4540 label: "label".to_string(),
4541 task_prompt: "subagent task prompt".to_string(),
4542 summary_prompt: "subagent summary prompt".to_string(),
4543 timeout_ms: None,
4544 allowed_tools: None,
4545 };
4546 let subagent_tool_use = LanguageModelToolUse {
4547 id: "subagent_1".into(),
4548 name: SubagentTool::NAME.into(),
4549 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4550 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4551 is_input_complete: true,
4552 thought_signature: None,
4553 };
4554 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4555 subagent_tool_use,
4556 ));
4557 model.end_last_completion_stream();
4558
4559 cx.run_until_parked();
4560
4561 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4562 thread
4563 .running_subagent_ids(cx)
4564 .get(0)
4565 .expect("subagent thread should be running")
4566 .clone()
4567 });
4568 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4569 agent
4570 .sessions
4571 .get(&subagent_session_id)
4572 .expect("subagent session should exist")
4573 .acp_thread
4574 .clone()
4575 });
4576
4577 model.send_last_completion_stream_text_chunk("subagent task response");
4578 model.end_last_completion_stream();
4579
4580 cx.run_until_parked();
4581
4582 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4583
4584 cx.run_until_parked();
4585
4586 send.await.unwrap();
4587
4588 acp_thread.read_with(cx, |thread, cx| {
4589 assert_eq!(thread.status(), ThreadStatus::Idle);
4590 assert_eq!(
4591 thread.to_markdown(cx),
4592 indoc! {"
4593 ## User
4594
4595 Prompt
4596
4597 ## Assistant
4598
4599 spawning subagent
4600
4601 **Tool Call: label**
4602 Status: Canceled
4603
4604 "}
4605 );
4606 });
4607 subagent_acp_thread.read_with(cx, |thread, cx| {
4608 assert_eq!(thread.status(), ThreadStatus::Idle);
4609 assert_eq!(
4610 thread.to_markdown(cx),
4611 indoc! {"
4612 ## User
4613
4614 subagent task prompt
4615
4616 ## Assistant
4617
4618 subagent task response
4619
4620 ## User
4621
4622 subagent summary prompt
4623
4624 "}
4625 );
4626 });
4627}
4628
4629#[gpui::test]
4630async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4631 init_test(cx);
4632
4633 cx.update(|cx| {
4634 cx.update_flags(true, vec!["subagents".to_string()]);
4635 });
4636
4637 let fs = FakeFs::new(cx.executor());
4638 fs.insert_tree(path!("/test"), json!({})).await;
4639 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4640 let project_context = cx.new(|_cx| ProjectContext::default());
4641 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4642 let context_server_registry =
4643 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4644 let model = Arc::new(FakeLanguageModel::default());
4645
4646 let environment = Rc::new(cx.update(|cx| {
4647 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4648 }));
4649
4650 let thread = cx.new(|cx| {
4651 let mut thread = Thread::new(
4652 project.clone(),
4653 project_context,
4654 context_server_registry,
4655 Templates::new(),
4656 Some(model),
4657 cx,
4658 );
4659 thread.add_default_tools(None, environment, cx);
4660 thread
4661 });
4662
4663 thread.read_with(cx, |thread, _| {
4664 assert!(
4665 thread.has_registered_tool(SubagentTool::NAME),
4666 "subagent tool should be present when feature flag is enabled"
4667 );
4668 });
4669}
4670
4671#[gpui::test]
4672async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4673 init_test(cx);
4674
4675 cx.update(|cx| {
4676 cx.update_flags(true, vec!["subagents".to_string()]);
4677 });
4678
4679 let fs = FakeFs::new(cx.executor());
4680 fs.insert_tree(path!("/test"), json!({})).await;
4681 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4682 let project_context = cx.new(|_cx| ProjectContext::default());
4683 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4684 let context_server_registry =
4685 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4686 let model = Arc::new(FakeLanguageModel::default());
4687
4688 let parent_thread = cx.new(|cx| {
4689 Thread::new(
4690 project.clone(),
4691 project_context,
4692 context_server_registry,
4693 Templates::new(),
4694 Some(model.clone()),
4695 cx,
4696 )
4697 });
4698
4699 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4700 subagent_thread.read_with(cx, |subagent_thread, cx| {
4701 assert!(subagent_thread.is_subagent());
4702 assert_eq!(subagent_thread.depth(), 1);
4703 assert_eq!(
4704 subagent_thread.model().map(|model| model.id()),
4705 Some(model.id())
4706 );
4707 assert_eq!(
4708 subagent_thread.parent_thread_id(),
4709 Some(parent_thread.read(cx).id().clone())
4710 );
4711 });
4712}
4713
4714#[gpui::test]
4715async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4716 init_test(cx);
4717
4718 cx.update(|cx| {
4719 cx.update_flags(true, vec!["subagents".to_string()]);
4720 });
4721
4722 let fs = FakeFs::new(cx.executor());
4723 fs.insert_tree(path!("/test"), json!({})).await;
4724 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4725 let project_context = cx.new(|_cx| ProjectContext::default());
4726 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4727 let context_server_registry =
4728 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4729 let model = Arc::new(FakeLanguageModel::default());
4730 let environment = Rc::new(cx.update(|cx| {
4731 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4732 }));
4733
4734 let deep_parent_thread = cx.new(|cx| {
4735 let mut thread = Thread::new(
4736 project.clone(),
4737 project_context,
4738 context_server_registry,
4739 Templates::new(),
4740 Some(model.clone()),
4741 cx,
4742 );
4743 thread.set_subagent_context(SubagentContext {
4744 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4745 depth: MAX_SUBAGENT_DEPTH - 1,
4746 });
4747 thread
4748 });
4749 let deep_subagent_thread = cx.new(|cx| {
4750 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4751 thread.add_default_tools(None, environment, cx);
4752 thread
4753 });
4754
4755 deep_subagent_thread.read_with(cx, |thread, _| {
4756 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4757 assert!(
4758 !thread.has_registered_tool(SubagentTool::NAME),
4759 "subagent tool should not be present at max depth"
4760 );
4761 });
4762}
4763
4764#[gpui::test]
4765async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4766 init_test(cx);
4767
4768 cx.update(|cx| {
4769 cx.update_flags(true, vec!["subagents".to_string()]);
4770 });
4771
4772 let fs = FakeFs::new(cx.executor());
4773 fs.insert_tree(path!("/test"), json!({})).await;
4774 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4775 let project_context = cx.new(|_cx| ProjectContext::default());
4776 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4777 let context_server_registry =
4778 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4779 let model = Arc::new(FakeLanguageModel::default());
4780
4781 let parent = cx.new(|cx| {
4782 Thread::new(
4783 project.clone(),
4784 project_context.clone(),
4785 context_server_registry.clone(),
4786 Templates::new(),
4787 Some(model.clone()),
4788 cx,
4789 )
4790 });
4791
4792 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4793
4794 parent.update(cx, |thread, _cx| {
4795 thread.register_running_subagent(subagent.downgrade());
4796 });
4797
4798 subagent
4799 .update(cx, |thread, cx| {
4800 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4801 })
4802 .unwrap();
4803 cx.run_until_parked();
4804
4805 subagent.read_with(cx, |thread, _| {
4806 assert!(!thread.is_turn_complete(), "subagent should be running");
4807 });
4808
4809 parent.update(cx, |thread, cx| {
4810 thread.cancel(cx).detach();
4811 });
4812
4813 subagent.read_with(cx, |thread, _| {
4814 assert!(
4815 thread.is_turn_complete(),
4816 "subagent should be cancelled when parent cancels"
4817 );
4818 });
4819}
4820
4821#[gpui::test]
4822async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4823 init_test(cx);
4824
4825 always_allow_tools(cx);
4826
4827 cx.update(|cx| {
4828 cx.update_flags(true, vec!["subagents".to_string()]);
4829 });
4830
4831 let fs = FakeFs::new(cx.executor());
4832 fs.insert_tree(path!("/test"), json!({})).await;
4833 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4834 let project_context = cx.new(|_cx| ProjectContext::default());
4835 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4836 let context_server_registry =
4837 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4838 cx.update(LanguageModelRegistry::test);
4839 let model = Arc::new(FakeLanguageModel::default());
4840 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4841 let native_agent = NativeAgent::new(
4842 project.clone(),
4843 thread_store,
4844 Templates::new(),
4845 None,
4846 fs,
4847 &mut cx.to_async(),
4848 )
4849 .await
4850 .unwrap();
4851 let parent_thread = cx.new(|cx| {
4852 Thread::new(
4853 project.clone(),
4854 project_context,
4855 context_server_registry,
4856 Templates::new(),
4857 Some(model.clone()),
4858 cx,
4859 )
4860 });
4861
4862 let subagent_handle = cx
4863 .update(|cx| {
4864 NativeThreadEnvironment::create_subagent_thread(
4865 native_agent.downgrade(),
4866 parent_thread.clone(),
4867 "some title".to_string(),
4868 "task prompt".to_string(),
4869 Some(Duration::from_millis(10)),
4870 None,
4871 cx,
4872 )
4873 })
4874 .expect("Failed to create subagent");
4875
4876 let summary_task =
4877 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4878
4879 cx.run_until_parked();
4880
4881 {
4882 let messages = model.pending_completions().last().unwrap().messages.clone();
4883 // Ensure that model received a system prompt
4884 assert_eq!(messages[0].role, Role::System);
4885 // Ensure that model received a task prompt
4886 assert_eq!(messages[1].role, Role::User);
4887 assert_eq!(
4888 messages[1].content,
4889 vec![MessageContent::Text("task prompt".to_string())]
4890 );
4891 }
4892
4893 model.send_last_completion_stream_text_chunk("Some task response...");
4894 model.end_last_completion_stream();
4895
4896 cx.run_until_parked();
4897
4898 {
4899 let messages = model.pending_completions().last().unwrap().messages.clone();
4900 assert_eq!(messages[2].role, Role::Assistant);
4901 assert_eq!(
4902 messages[2].content,
4903 vec![MessageContent::Text("Some task response...".to_string())]
4904 );
4905 // Ensure that model received a summary prompt
4906 assert_eq!(messages[3].role, Role::User);
4907 assert_eq!(
4908 messages[3].content,
4909 vec![MessageContent::Text("summary prompt".to_string())]
4910 );
4911 }
4912
4913 model.send_last_completion_stream_text_chunk("Some summary...");
4914 model.end_last_completion_stream();
4915
4916 let result = summary_task.await;
4917 assert_eq!(result.unwrap(), "Some summary...\n");
4918}
4919
4920#[gpui::test]
4921async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4922 cx: &mut TestAppContext,
4923) {
4924 init_test(cx);
4925
4926 always_allow_tools(cx);
4927
4928 cx.update(|cx| {
4929 cx.update_flags(true, vec!["subagents".to_string()]);
4930 });
4931
4932 let fs = FakeFs::new(cx.executor());
4933 fs.insert_tree(path!("/test"), json!({})).await;
4934 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4935 let project_context = cx.new(|_cx| ProjectContext::default());
4936 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4937 let context_server_registry =
4938 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4939 cx.update(LanguageModelRegistry::test);
4940 let model = Arc::new(FakeLanguageModel::default());
4941 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4942 let native_agent = NativeAgent::new(
4943 project.clone(),
4944 thread_store,
4945 Templates::new(),
4946 None,
4947 fs,
4948 &mut cx.to_async(),
4949 )
4950 .await
4951 .unwrap();
4952 let parent_thread = cx.new(|cx| {
4953 Thread::new(
4954 project.clone(),
4955 project_context,
4956 context_server_registry,
4957 Templates::new(),
4958 Some(model.clone()),
4959 cx,
4960 )
4961 });
4962
4963 let subagent_handle = cx
4964 .update(|cx| {
4965 NativeThreadEnvironment::create_subagent_thread(
4966 native_agent.downgrade(),
4967 parent_thread.clone(),
4968 "some title".to_string(),
4969 "task prompt".to_string(),
4970 Some(Duration::from_millis(100)),
4971 None,
4972 cx,
4973 )
4974 })
4975 .expect("Failed to create subagent");
4976
4977 let summary_task =
4978 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4979
4980 cx.run_until_parked();
4981
4982 {
4983 let messages = model.pending_completions().last().unwrap().messages.clone();
4984 // Ensure that model received a system prompt
4985 assert_eq!(messages[0].role, Role::System);
4986 // Ensure that model received a task prompt
4987 assert_eq!(
4988 messages[1].content,
4989 vec![MessageContent::Text("task prompt".to_string())]
4990 );
4991 }
4992
4993 // Don't complete the initial model stream — let the timeout expire instead.
4994 cx.executor().advance_clock(Duration::from_millis(200));
4995 cx.run_until_parked();
4996
4997 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
4998 // instead of the summary_prompt.
4999 {
5000 let messages = model.pending_completions().last().unwrap().messages.clone();
5001 let last_user_message = messages
5002 .iter()
5003 .rev()
5004 .find(|m| m.role == Role::User)
5005 .unwrap();
5006 assert_eq!(
5007 last_user_message.content,
5008 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
5009 );
5010 }
5011
5012 model.send_last_completion_stream_text_chunk("Some context low response...");
5013 model.end_last_completion_stream();
5014
5015 let result = summary_task.await;
5016 assert_eq!(result.unwrap(), "Some context low response...\n");
5017}
5018
5019#[gpui::test]
5020async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
5021 init_test(cx);
5022
5023 always_allow_tools(cx);
5024
5025 cx.update(|cx| {
5026 cx.update_flags(true, vec!["subagents".to_string()]);
5027 });
5028
5029 let fs = FakeFs::new(cx.executor());
5030 fs.insert_tree(path!("/test"), json!({})).await;
5031 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
5032 let project_context = cx.new(|_cx| ProjectContext::default());
5033 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5034 let context_server_registry =
5035 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5036 cx.update(LanguageModelRegistry::test);
5037 let model = Arc::new(FakeLanguageModel::default());
5038 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5039 let native_agent = NativeAgent::new(
5040 project.clone(),
5041 thread_store,
5042 Templates::new(),
5043 None,
5044 fs,
5045 &mut cx.to_async(),
5046 )
5047 .await
5048 .unwrap();
5049 let parent_thread = cx.new(|cx| {
5050 let mut thread = Thread::new(
5051 project.clone(),
5052 project_context,
5053 context_server_registry,
5054 Templates::new(),
5055 Some(model.clone()),
5056 cx,
5057 );
5058 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
5059 thread.add_tool(GrepTool::new(project.clone()), None);
5060 thread
5061 });
5062
5063 let _subagent_handle = cx
5064 .update(|cx| {
5065 NativeThreadEnvironment::create_subagent_thread(
5066 native_agent.downgrade(),
5067 parent_thread.clone(),
5068 "some title".to_string(),
5069 "task prompt".to_string(),
5070 Some(Duration::from_millis(10)),
5071 None,
5072 cx,
5073 )
5074 })
5075 .expect("Failed to create subagent");
5076
5077 cx.run_until_parked();
5078
5079 let tools = model
5080 .pending_completions()
5081 .last()
5082 .unwrap()
5083 .tools
5084 .iter()
5085 .map(|tool| tool.name.clone())
5086 .collect::<Vec<_>>();
5087 assert_eq!(tools.len(), 2);
5088 assert!(tools.contains(&"grep".to_string()));
5089 assert!(tools.contains(&"list_directory".to_string()));
5090}
5091
5092#[gpui::test]
5093async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
5094 init_test(cx);
5095
5096 always_allow_tools(cx);
5097
5098 cx.update(|cx| {
5099 cx.update_flags(true, vec!["subagents".to_string()]);
5100 });
5101
5102 let fs = FakeFs::new(cx.executor());
5103 fs.insert_tree(path!("/test"), json!({})).await;
5104 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
5105 let project_context = cx.new(|_cx| ProjectContext::default());
5106 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5107 let context_server_registry =
5108 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5109 cx.update(LanguageModelRegistry::test);
5110 let model = Arc::new(FakeLanguageModel::default());
5111 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5112 let native_agent = NativeAgent::new(
5113 project.clone(),
5114 thread_store,
5115 Templates::new(),
5116 None,
5117 fs,
5118 &mut cx.to_async(),
5119 )
5120 .await
5121 .unwrap();
5122 let parent_thread = cx.new(|cx| {
5123 let mut thread = Thread::new(
5124 project.clone(),
5125 project_context,
5126 context_server_registry,
5127 Templates::new(),
5128 Some(model.clone()),
5129 cx,
5130 );
5131 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
5132 thread.add_tool(GrepTool::new(project.clone()), None);
5133 thread
5134 });
5135
5136 let _subagent_handle = cx
5137 .update(|cx| {
5138 NativeThreadEnvironment::create_subagent_thread(
5139 native_agent.downgrade(),
5140 parent_thread.clone(),
5141 "some title".to_string(),
5142 "task prompt".to_string(),
5143 Some(Duration::from_millis(10)),
5144 Some(vec!["grep".to_string()]),
5145 cx,
5146 )
5147 })
5148 .expect("Failed to create subagent");
5149
5150 cx.run_until_parked();
5151
5152 let tools = model
5153 .pending_completions()
5154 .last()
5155 .unwrap()
5156 .tools
5157 .iter()
5158 .map(|tool| tool.name.clone())
5159 .collect::<Vec<_>>();
5160 assert_eq!(tools, vec!["grep"]);
5161}
5162
5163#[gpui::test]
5164async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5165 init_test(cx);
5166
5167 let fs = FakeFs::new(cx.executor());
5168 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5169 .await;
5170 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5171
5172 cx.update(|cx| {
5173 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5174 settings.tool_permissions.tools.insert(
5175 EditFileTool::NAME.into(),
5176 agent_settings::ToolRules {
5177 default: Some(settings::ToolPermissionMode::Allow),
5178 always_allow: vec![],
5179 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5180 always_confirm: vec![],
5181 invalid_patterns: vec![],
5182 },
5183 );
5184 agent_settings::AgentSettings::override_global(settings, cx);
5185 });
5186
5187 let context_server_registry =
5188 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5189 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5190 let templates = crate::Templates::new();
5191 let thread = cx.new(|cx| {
5192 crate::Thread::new(
5193 project.clone(),
5194 cx.new(|_cx| prompt_store::ProjectContext::default()),
5195 context_server_registry,
5196 templates.clone(),
5197 None,
5198 cx,
5199 )
5200 });
5201
5202 #[allow(clippy::arc_with_non_send_sync)]
5203 let tool = Arc::new(crate::EditFileTool::new(
5204 project.clone(),
5205 thread.downgrade(),
5206 language_registry,
5207 templates,
5208 ));
5209 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5210
5211 let task = cx.update(|cx| {
5212 tool.run(
5213 crate::EditFileToolInput {
5214 display_description: "Edit sensitive file".to_string(),
5215 path: "root/sensitive_config.txt".into(),
5216 mode: crate::EditFileMode::Edit,
5217 },
5218 event_stream,
5219 cx,
5220 )
5221 });
5222
5223 let result = task.await;
5224 assert!(result.is_err(), "expected edit to be blocked");
5225 assert!(
5226 result.unwrap_err().to_string().contains("blocked"),
5227 "error should mention the edit was blocked"
5228 );
5229}
5230
5231#[gpui::test]
5232async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5233 init_test(cx);
5234
5235 let fs = FakeFs::new(cx.executor());
5236 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5237 .await;
5238 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5239
5240 cx.update(|cx| {
5241 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5242 settings.tool_permissions.tools.insert(
5243 DeletePathTool::NAME.into(),
5244 agent_settings::ToolRules {
5245 default: Some(settings::ToolPermissionMode::Allow),
5246 always_allow: vec![],
5247 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5248 always_confirm: vec![],
5249 invalid_patterns: vec![],
5250 },
5251 );
5252 agent_settings::AgentSettings::override_global(settings, cx);
5253 });
5254
5255 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5256
5257 #[allow(clippy::arc_with_non_send_sync)]
5258 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5259 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5260
5261 let task = cx.update(|cx| {
5262 tool.run(
5263 crate::DeletePathToolInput {
5264 path: "root/important_data.txt".to_string(),
5265 },
5266 event_stream,
5267 cx,
5268 )
5269 });
5270
5271 let result = task.await;
5272 assert!(result.is_err(), "expected deletion to be blocked");
5273 assert!(
5274 result.unwrap_err().to_string().contains("blocked"),
5275 "error should mention the deletion was blocked"
5276 );
5277}
5278
5279#[gpui::test]
5280async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5281 init_test(cx);
5282
5283 let fs = FakeFs::new(cx.executor());
5284 fs.insert_tree(
5285 "/root",
5286 json!({
5287 "safe.txt": "content",
5288 "protected": {}
5289 }),
5290 )
5291 .await;
5292 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5293
5294 cx.update(|cx| {
5295 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5296 settings.tool_permissions.tools.insert(
5297 MovePathTool::NAME.into(),
5298 agent_settings::ToolRules {
5299 default: Some(settings::ToolPermissionMode::Allow),
5300 always_allow: vec![],
5301 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5302 always_confirm: vec![],
5303 invalid_patterns: vec![],
5304 },
5305 );
5306 agent_settings::AgentSettings::override_global(settings, cx);
5307 });
5308
5309 #[allow(clippy::arc_with_non_send_sync)]
5310 let tool = Arc::new(crate::MovePathTool::new(project));
5311 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5312
5313 let task = cx.update(|cx| {
5314 tool.run(
5315 crate::MovePathToolInput {
5316 source_path: "root/safe.txt".to_string(),
5317 destination_path: "root/protected/safe.txt".to_string(),
5318 },
5319 event_stream,
5320 cx,
5321 )
5322 });
5323
5324 let result = task.await;
5325 assert!(
5326 result.is_err(),
5327 "expected move to be blocked due to destination path"
5328 );
5329 assert!(
5330 result.unwrap_err().to_string().contains("blocked"),
5331 "error should mention the move was blocked"
5332 );
5333}
5334
5335#[gpui::test]
5336async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5337 init_test(cx);
5338
5339 let fs = FakeFs::new(cx.executor());
5340 fs.insert_tree(
5341 "/root",
5342 json!({
5343 "secret.txt": "secret content",
5344 "public": {}
5345 }),
5346 )
5347 .await;
5348 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5349
5350 cx.update(|cx| {
5351 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5352 settings.tool_permissions.tools.insert(
5353 MovePathTool::NAME.into(),
5354 agent_settings::ToolRules {
5355 default: Some(settings::ToolPermissionMode::Allow),
5356 always_allow: vec![],
5357 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5358 always_confirm: vec![],
5359 invalid_patterns: vec![],
5360 },
5361 );
5362 agent_settings::AgentSettings::override_global(settings, cx);
5363 });
5364
5365 #[allow(clippy::arc_with_non_send_sync)]
5366 let tool = Arc::new(crate::MovePathTool::new(project));
5367 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5368
5369 let task = cx.update(|cx| {
5370 tool.run(
5371 crate::MovePathToolInput {
5372 source_path: "root/secret.txt".to_string(),
5373 destination_path: "root/public/not_secret.txt".to_string(),
5374 },
5375 event_stream,
5376 cx,
5377 )
5378 });
5379
5380 let result = task.await;
5381 assert!(
5382 result.is_err(),
5383 "expected move to be blocked due to source path"
5384 );
5385 assert!(
5386 result.unwrap_err().to_string().contains("blocked"),
5387 "error should mention the move was blocked"
5388 );
5389}
5390
5391#[gpui::test]
5392async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5393 init_test(cx);
5394
5395 let fs = FakeFs::new(cx.executor());
5396 fs.insert_tree(
5397 "/root",
5398 json!({
5399 "confidential.txt": "confidential data",
5400 "dest": {}
5401 }),
5402 )
5403 .await;
5404 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5405
5406 cx.update(|cx| {
5407 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5408 settings.tool_permissions.tools.insert(
5409 CopyPathTool::NAME.into(),
5410 agent_settings::ToolRules {
5411 default: Some(settings::ToolPermissionMode::Allow),
5412 always_allow: vec![],
5413 always_deny: vec![
5414 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5415 ],
5416 always_confirm: vec![],
5417 invalid_patterns: vec![],
5418 },
5419 );
5420 agent_settings::AgentSettings::override_global(settings, cx);
5421 });
5422
5423 #[allow(clippy::arc_with_non_send_sync)]
5424 let tool = Arc::new(crate::CopyPathTool::new(project));
5425 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5426
5427 let task = cx.update(|cx| {
5428 tool.run(
5429 crate::CopyPathToolInput {
5430 source_path: "root/confidential.txt".to_string(),
5431 destination_path: "root/dest/copy.txt".to_string(),
5432 },
5433 event_stream,
5434 cx,
5435 )
5436 });
5437
5438 let result = task.await;
5439 assert!(result.is_err(), "expected copy to be blocked");
5440 assert!(
5441 result.unwrap_err().to_string().contains("blocked"),
5442 "error should mention the copy was blocked"
5443 );
5444}
5445
5446#[gpui::test]
5447async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5448 init_test(cx);
5449
5450 let fs = FakeFs::new(cx.executor());
5451 fs.insert_tree(
5452 "/root",
5453 json!({
5454 "normal.txt": "normal content",
5455 "readonly": {
5456 "config.txt": "readonly content"
5457 }
5458 }),
5459 )
5460 .await;
5461 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5462
5463 cx.update(|cx| {
5464 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5465 settings.tool_permissions.tools.insert(
5466 SaveFileTool::NAME.into(),
5467 agent_settings::ToolRules {
5468 default: Some(settings::ToolPermissionMode::Allow),
5469 always_allow: vec![],
5470 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5471 always_confirm: vec![],
5472 invalid_patterns: vec![],
5473 },
5474 );
5475 agent_settings::AgentSettings::override_global(settings, cx);
5476 });
5477
5478 #[allow(clippy::arc_with_non_send_sync)]
5479 let tool = Arc::new(crate::SaveFileTool::new(project));
5480 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5481
5482 let task = cx.update(|cx| {
5483 tool.run(
5484 crate::SaveFileToolInput {
5485 paths: vec![
5486 std::path::PathBuf::from("root/normal.txt"),
5487 std::path::PathBuf::from("root/readonly/config.txt"),
5488 ],
5489 },
5490 event_stream,
5491 cx,
5492 )
5493 });
5494
5495 let result = task.await;
5496 assert!(
5497 result.is_err(),
5498 "expected save to be blocked due to denied path"
5499 );
5500 assert!(
5501 result.unwrap_err().to_string().contains("blocked"),
5502 "error should mention the save was blocked"
5503 );
5504}
5505
5506#[gpui::test]
5507async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5508 init_test(cx);
5509
5510 let fs = FakeFs::new(cx.executor());
5511 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5512 .await;
5513 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5514
5515 cx.update(|cx| {
5516 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5517 settings.tool_permissions.tools.insert(
5518 SaveFileTool::NAME.into(),
5519 agent_settings::ToolRules {
5520 default: Some(settings::ToolPermissionMode::Allow),
5521 always_allow: vec![],
5522 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5523 always_confirm: vec![],
5524 invalid_patterns: vec![],
5525 },
5526 );
5527 agent_settings::AgentSettings::override_global(settings, cx);
5528 });
5529
5530 #[allow(clippy::arc_with_non_send_sync)]
5531 let tool = Arc::new(crate::SaveFileTool::new(project));
5532 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5533
5534 let task = cx.update(|cx| {
5535 tool.run(
5536 crate::SaveFileToolInput {
5537 paths: vec![std::path::PathBuf::from("root/config.secret")],
5538 },
5539 event_stream,
5540 cx,
5541 )
5542 });
5543
5544 let result = task.await;
5545 assert!(result.is_err(), "expected save to be blocked");
5546 assert!(
5547 result.unwrap_err().to_string().contains("blocked"),
5548 "error should mention the save was blocked"
5549 );
5550}
5551
5552#[gpui::test]
5553async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5554 init_test(cx);
5555
5556 cx.update(|cx| {
5557 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5558 settings.tool_permissions.tools.insert(
5559 WebSearchTool::NAME.into(),
5560 agent_settings::ToolRules {
5561 default: Some(settings::ToolPermissionMode::Allow),
5562 always_allow: vec![],
5563 always_deny: vec![
5564 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5565 ],
5566 always_confirm: vec![],
5567 invalid_patterns: vec![],
5568 },
5569 );
5570 agent_settings::AgentSettings::override_global(settings, cx);
5571 });
5572
5573 #[allow(clippy::arc_with_non_send_sync)]
5574 let tool = Arc::new(crate::WebSearchTool);
5575 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5576
5577 let input: crate::WebSearchToolInput =
5578 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5579
5580 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5581
5582 let result = task.await;
5583 assert!(result.is_err(), "expected search to be blocked");
5584 assert!(
5585 result.unwrap_err().to_string().contains("blocked"),
5586 "error should mention the search was blocked"
5587 );
5588}
5589
5590#[gpui::test]
5591async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5592 init_test(cx);
5593
5594 let fs = FakeFs::new(cx.executor());
5595 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5596 .await;
5597 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5598
5599 cx.update(|cx| {
5600 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5601 settings.tool_permissions.tools.insert(
5602 EditFileTool::NAME.into(),
5603 agent_settings::ToolRules {
5604 default: Some(settings::ToolPermissionMode::Confirm),
5605 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5606 always_deny: vec![],
5607 always_confirm: vec![],
5608 invalid_patterns: vec![],
5609 },
5610 );
5611 agent_settings::AgentSettings::override_global(settings, cx);
5612 });
5613
5614 let context_server_registry =
5615 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5616 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5617 let templates = crate::Templates::new();
5618 let thread = cx.new(|cx| {
5619 crate::Thread::new(
5620 project.clone(),
5621 cx.new(|_cx| prompt_store::ProjectContext::default()),
5622 context_server_registry,
5623 templates.clone(),
5624 None,
5625 cx,
5626 )
5627 });
5628
5629 #[allow(clippy::arc_with_non_send_sync)]
5630 let tool = Arc::new(crate::EditFileTool::new(
5631 project,
5632 thread.downgrade(),
5633 language_registry,
5634 templates,
5635 ));
5636 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5637
5638 let _task = cx.update(|cx| {
5639 tool.run(
5640 crate::EditFileToolInput {
5641 display_description: "Edit README".to_string(),
5642 path: "root/README.md".into(),
5643 mode: crate::EditFileMode::Edit,
5644 },
5645 event_stream,
5646 cx,
5647 )
5648 });
5649
5650 cx.run_until_parked();
5651
5652 let event = rx.try_next();
5653 assert!(
5654 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5655 "expected no authorization request for allowed .md file"
5656 );
5657}
5658
5659#[gpui::test]
5660async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5661 init_test(cx);
5662
5663 let fs = FakeFs::new(cx.executor());
5664 fs.insert_tree(
5665 "/root",
5666 json!({
5667 ".zed": {
5668 "settings.json": "{}"
5669 },
5670 "README.md": "# Hello"
5671 }),
5672 )
5673 .await;
5674 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5675
5676 cx.update(|cx| {
5677 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5678 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5679 agent_settings::AgentSettings::override_global(settings, cx);
5680 });
5681
5682 let context_server_registry =
5683 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5684 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5685 let templates = crate::Templates::new();
5686 let thread = cx.new(|cx| {
5687 crate::Thread::new(
5688 project.clone(),
5689 cx.new(|_cx| prompt_store::ProjectContext::default()),
5690 context_server_registry,
5691 templates.clone(),
5692 None,
5693 cx,
5694 )
5695 });
5696
5697 #[allow(clippy::arc_with_non_send_sync)]
5698 let tool = Arc::new(crate::EditFileTool::new(
5699 project,
5700 thread.downgrade(),
5701 language_registry,
5702 templates,
5703 ));
5704
5705 // Editing a file inside .zed/ should still prompt even with global default: allow,
5706 // because local settings paths are sensitive and require confirmation regardless.
5707 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5708 let _task = cx.update(|cx| {
5709 tool.run(
5710 crate::EditFileToolInput {
5711 display_description: "Edit local settings".to_string(),
5712 path: "root/.zed/settings.json".into(),
5713 mode: crate::EditFileMode::Edit,
5714 },
5715 event_stream,
5716 cx,
5717 )
5718 });
5719
5720 let _update = rx.expect_update_fields().await;
5721 let _auth = rx.expect_authorization().await;
5722}
5723
5724#[gpui::test]
5725async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5726 init_test(cx);
5727
5728 cx.update(|cx| {
5729 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5730 settings.tool_permissions.tools.insert(
5731 FetchTool::NAME.into(),
5732 agent_settings::ToolRules {
5733 default: Some(settings::ToolPermissionMode::Allow),
5734 always_allow: vec![],
5735 always_deny: vec![
5736 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5737 ],
5738 always_confirm: vec![],
5739 invalid_patterns: vec![],
5740 },
5741 );
5742 agent_settings::AgentSettings::override_global(settings, cx);
5743 });
5744
5745 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5746
5747 #[allow(clippy::arc_with_non_send_sync)]
5748 let tool = Arc::new(crate::FetchTool::new(http_client));
5749 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5750
5751 let input: crate::FetchToolInput =
5752 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5753
5754 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5755
5756 let result = task.await;
5757 assert!(result.is_err(), "expected fetch to be blocked");
5758 assert!(
5759 result.unwrap_err().to_string().contains("blocked"),
5760 "error should mention the fetch was blocked"
5761 );
5762}
5763
5764#[gpui::test]
5765async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5766 init_test(cx);
5767
5768 cx.update(|cx| {
5769 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5770 settings.tool_permissions.tools.insert(
5771 FetchTool::NAME.into(),
5772 agent_settings::ToolRules {
5773 default: Some(settings::ToolPermissionMode::Confirm),
5774 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5775 always_deny: vec![],
5776 always_confirm: vec![],
5777 invalid_patterns: vec![],
5778 },
5779 );
5780 agent_settings::AgentSettings::override_global(settings, cx);
5781 });
5782
5783 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5784
5785 #[allow(clippy::arc_with_non_send_sync)]
5786 let tool = Arc::new(crate::FetchTool::new(http_client));
5787 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5788
5789 let input: crate::FetchToolInput =
5790 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5791
5792 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5793
5794 cx.run_until_parked();
5795
5796 let event = rx.try_next();
5797 assert!(
5798 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5799 "expected no authorization request for allowed docs.rs URL"
5800 );
5801}
5802
5803#[gpui::test]
5804async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5805 init_test(cx);
5806 always_allow_tools(cx);
5807
5808 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5809 let fake_model = model.as_fake();
5810
5811 // Add a tool so we can simulate tool calls
5812 thread.update(cx, |thread, _cx| {
5813 thread.add_tool(EchoTool, None);
5814 });
5815
5816 // Start a turn by sending a message
5817 let mut events = thread
5818 .update(cx, |thread, cx| {
5819 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5820 })
5821 .unwrap();
5822 cx.run_until_parked();
5823
5824 // Simulate the model making a tool call
5825 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5826 LanguageModelToolUse {
5827 id: "tool_1".into(),
5828 name: "echo".into(),
5829 raw_input: r#"{"text": "hello"}"#.into(),
5830 input: json!({"text": "hello"}),
5831 is_input_complete: true,
5832 thought_signature: None,
5833 },
5834 ));
5835 fake_model
5836 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5837
5838 // Signal that a message is queued before ending the stream
5839 thread.update(cx, |thread, _cx| {
5840 thread.set_has_queued_message(true);
5841 });
5842
5843 // Now end the stream - tool will run, and the boundary check should see the queue
5844 fake_model.end_last_completion_stream();
5845
5846 // Collect all events until the turn stops
5847 let all_events = collect_events_until_stop(&mut events, cx).await;
5848
5849 // Verify we received the tool call event
5850 let tool_call_ids: Vec<_> = all_events
5851 .iter()
5852 .filter_map(|e| match e {
5853 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5854 _ => None,
5855 })
5856 .collect();
5857 assert_eq!(
5858 tool_call_ids,
5859 vec!["tool_1"],
5860 "Should have received a tool call event for our echo tool"
5861 );
5862
5863 // The turn should have stopped with EndTurn
5864 let stop_reasons = stop_events(all_events);
5865 assert_eq!(
5866 stop_reasons,
5867 vec![acp::StopReason::EndTurn],
5868 "Turn should have ended after tool completion due to queued message"
5869 );
5870
5871 // Verify the queued message flag is still set
5872 thread.update(cx, |thread, _cx| {
5873 assert!(
5874 thread.has_queued_message(),
5875 "Should still have queued message flag set"
5876 );
5877 });
5878
5879 // Thread should be idle now
5880 thread.update(cx, |thread, _cx| {
5881 assert!(
5882 thread.is_turn_complete(),
5883 "Thread should not be running after turn ends"
5884 );
5885 });
5886}