1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
68struct FakeTerminalHandle {
69 killed: Arc<AtomicBool>,
70 stopped_by_user: Arc<AtomicBool>,
71 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
72 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
73 output: acp::TerminalOutputResponse,
74 id: acp::TerminalId,
75}
76
77impl FakeTerminalHandle {
78 fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
136impl crate::TerminalHandle for FakeTerminalHandle {
137 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
138 Ok(self.id.clone())
139 }
140
141 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
142 Ok(self.output.clone())
143 }
144
145 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
146 Ok(self.wait_for_exit.clone())
147 }
148
149 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
150 self.killed.store(true, Ordering::SeqCst);
151 self.signal_exit();
152 Ok(())
153 }
154
155 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
156 Ok(self.stopped_by_user.load(Ordering::SeqCst))
157 }
158}
159
160struct FakeSubagentHandle {
161 session_id: acp::SessionId,
162 wait_for_summary_task: Shared<Task<String>>,
163}
164
165impl SubagentHandle for FakeSubagentHandle {
166 fn id(&self) -> acp::SessionId {
167 self.session_id.clone()
168 }
169
170 fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
171 let task = self.wait_for_summary_task.clone();
172 cx.background_spawn(async move { Ok(task.await) })
173 }
174}
175
176#[derive(Default)]
177struct FakeThreadEnvironment {
178 terminal_handle: Option<Rc<FakeTerminalHandle>>,
179 subagent_handle: Option<Rc<FakeSubagentHandle>>,
180}
181
182impl FakeThreadEnvironment {
183 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
184 Self {
185 terminal_handle: Some(terminal_handle.into()),
186 ..self
187 }
188 }
189}
190
191impl crate::ThreadEnvironment for FakeThreadEnvironment {
192 fn create_terminal(
193 &self,
194 _command: String,
195 _cwd: Option<std::path::PathBuf>,
196 _output_byte_limit: Option<u64>,
197 _cx: &mut AsyncApp,
198 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
199 let handle = self
200 .terminal_handle
201 .clone()
202 .expect("Terminal handle not available on FakeThreadEnvironment");
203 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
204 }
205
206 fn create_subagent(
207 &self,
208 _parent_thread: Entity<Thread>,
209 _label: String,
210 _initial_prompt: String,
211 _cx: &mut App,
212 ) -> Result<Rc<dyn SubagentHandle>> {
213 Ok(self
214 .subagent_handle
215 .clone()
216 .expect("Subagent handle not available on FakeThreadEnvironment")
217 as Rc<dyn SubagentHandle>)
218 }
219}
220
221/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
222struct MultiTerminalEnvironment {
223 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
224}
225
226impl MultiTerminalEnvironment {
227 fn new() -> Self {
228 Self {
229 handles: std::cell::RefCell::new(Vec::new()),
230 }
231 }
232
233 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
234 self.handles.borrow().clone()
235 }
236}
237
238impl crate::ThreadEnvironment for MultiTerminalEnvironment {
239 fn create_terminal(
240 &self,
241 _command: String,
242 _cwd: Option<std::path::PathBuf>,
243 _output_byte_limit: Option<u64>,
244 cx: &mut AsyncApp,
245 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
246 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
247 self.handles.borrow_mut().push(handle.clone());
248 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
249 }
250
251 fn create_subagent(
252 &self,
253 _parent_thread: Entity<Thread>,
254 _label: String,
255 _initial_prompt: String,
256 _cx: &mut App,
257 ) -> Result<Rc<dyn SubagentHandle>> {
258 unimplemented!()
259 }
260}
261
262fn always_allow_tools(cx: &mut TestAppContext) {
263 cx.update(|cx| {
264 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
265 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
266 agent_settings::AgentSettings::override_global(settings, cx);
267 });
268}
269
270#[gpui::test]
271async fn test_echo(cx: &mut TestAppContext) {
272 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
273 let fake_model = model.as_fake();
274
275 let events = thread
276 .update(cx, |thread, cx| {
277 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
278 })
279 .unwrap();
280 cx.run_until_parked();
281 fake_model.send_last_completion_stream_text_chunk("Hello");
282 fake_model
283 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
284 fake_model.end_last_completion_stream();
285
286 let events = events.collect().await;
287 thread.update(cx, |thread, _cx| {
288 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
289 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
290 });
291 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
292}
293
294#[gpui::test]
295async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
296 init_test(cx);
297 always_allow_tools(cx);
298
299 let fs = FakeFs::new(cx.executor());
300 let project = Project::test(fs, [], cx).await;
301
302 let environment = Rc::new(cx.update(|cx| {
303 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
304 }));
305 let handle = environment.terminal_handle.clone().unwrap();
306
307 #[allow(clippy::arc_with_non_send_sync)]
308 let tool = Arc::new(crate::TerminalTool::new(project, environment));
309 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
310
311 let task = cx.update(|cx| {
312 tool.run(
313 crate::TerminalToolInput {
314 command: "sleep 1000".to_string(),
315 cd: ".".to_string(),
316 timeout_ms: Some(5),
317 },
318 event_stream,
319 cx,
320 )
321 });
322
323 let update = rx.expect_update_fields().await;
324 assert!(
325 update.content.iter().any(|blocks| {
326 blocks
327 .iter()
328 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
329 }),
330 "expected tool call update to include terminal content"
331 );
332
333 let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());
334
335 let deadline = std::time::Instant::now() + Duration::from_millis(500);
336 loop {
337 if let Some(result) = task_future.as_mut().now_or_never() {
338 let result = result.expect("terminal tool task should complete");
339
340 assert!(
341 handle.was_killed(),
342 "expected terminal handle to be killed on timeout"
343 );
344 assert!(
345 result.contains("partial output"),
346 "expected result to include terminal output, got: {result}"
347 );
348 return;
349 }
350
351 if std::time::Instant::now() >= deadline {
352 panic!("timed out waiting for terminal tool task to complete");
353 }
354
355 cx.run_until_parked();
356 cx.background_executor.timer(Duration::from_millis(1)).await;
357 }
358}
359
360#[gpui::test]
361#[ignore]
362async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
363 init_test(cx);
364 always_allow_tools(cx);
365
366 let fs = FakeFs::new(cx.executor());
367 let project = Project::test(fs, [], cx).await;
368
369 let environment = Rc::new(cx.update(|cx| {
370 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
371 }));
372 let handle = environment.terminal_handle.clone().unwrap();
373
374 #[allow(clippy::arc_with_non_send_sync)]
375 let tool = Arc::new(crate::TerminalTool::new(project, environment));
376 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
377
378 let _task = cx.update(|cx| {
379 tool.run(
380 crate::TerminalToolInput {
381 command: "sleep 1000".to_string(),
382 cd: ".".to_string(),
383 timeout_ms: None,
384 },
385 event_stream,
386 cx,
387 )
388 });
389
390 let update = rx.expect_update_fields().await;
391 assert!(
392 update.content.iter().any(|blocks| {
393 blocks
394 .iter()
395 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
396 }),
397 "expected tool call update to include terminal content"
398 );
399
400 cx.background_executor
401 .timer(Duration::from_millis(25))
402 .await;
403
404 assert!(
405 !handle.was_killed(),
406 "did not expect terminal handle to be killed without a timeout"
407 );
408}
409
410#[gpui::test]
411async fn test_thinking(cx: &mut TestAppContext) {
412 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
413 let fake_model = model.as_fake();
414
415 let events = thread
416 .update(cx, |thread, cx| {
417 thread.send(
418 UserMessageId::new(),
419 [indoc! {"
420 Testing:
421
422 Generate a thinking step where you just think the word 'Think',
423 and have your final answer be 'Hello'
424 "}],
425 cx,
426 )
427 })
428 .unwrap();
429 cx.run_until_parked();
430 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
431 text: "Think".to_string(),
432 signature: None,
433 });
434 fake_model.send_last_completion_stream_text_chunk("Hello");
435 fake_model
436 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
437 fake_model.end_last_completion_stream();
438
439 let events = events.collect().await;
440 thread.update(cx, |thread, _cx| {
441 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
442 assert_eq!(
443 thread.last_message().unwrap().to_markdown(),
444 indoc! {"
445 <think>Think</think>
446 Hello
447 "}
448 )
449 });
450 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
451}
452
453#[gpui::test]
454async fn test_system_prompt(cx: &mut TestAppContext) {
455 let ThreadTest {
456 model,
457 thread,
458 project_context,
459 ..
460 } = setup(cx, TestModel::Fake).await;
461 let fake_model = model.as_fake();
462
463 project_context.update(cx, |project_context, _cx| {
464 project_context.shell = "test-shell".into()
465 });
466 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
467 thread
468 .update(cx, |thread, cx| {
469 thread.send(UserMessageId::new(), ["abc"], cx)
470 })
471 .unwrap();
472 cx.run_until_parked();
473 let mut pending_completions = fake_model.pending_completions();
474 assert_eq!(
475 pending_completions.len(),
476 1,
477 "unexpected pending completions: {:?}",
478 pending_completions
479 );
480
481 let pending_completion = pending_completions.pop().unwrap();
482 assert_eq!(pending_completion.messages[0].role, Role::System);
483
484 let system_message = &pending_completion.messages[0];
485 let system_prompt = system_message.content[0].to_str().unwrap();
486 assert!(
487 system_prompt.contains("test-shell"),
488 "unexpected system message: {:?}",
489 system_message
490 );
491 assert!(
492 system_prompt.contains("## Fixing Diagnostics"),
493 "unexpected system message: {:?}",
494 system_message
495 );
496}
497
498#[gpui::test]
499async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
500 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
501 let fake_model = model.as_fake();
502
503 thread
504 .update(cx, |thread, cx| {
505 thread.send(UserMessageId::new(), ["abc"], cx)
506 })
507 .unwrap();
508 cx.run_until_parked();
509 let mut pending_completions = fake_model.pending_completions();
510 assert_eq!(
511 pending_completions.len(),
512 1,
513 "unexpected pending completions: {:?}",
514 pending_completions
515 );
516
517 let pending_completion = pending_completions.pop().unwrap();
518 assert_eq!(pending_completion.messages[0].role, Role::System);
519
520 let system_message = &pending_completion.messages[0];
521 let system_prompt = system_message.content[0].to_str().unwrap();
522 assert!(
523 !system_prompt.contains("## Tool Use"),
524 "unexpected system message: {:?}",
525 system_message
526 );
527 assert!(
528 !system_prompt.contains("## Fixing Diagnostics"),
529 "unexpected system message: {:?}",
530 system_message
531 );
532}
533
534#[gpui::test]
535async fn test_prompt_caching(cx: &mut TestAppContext) {
536 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
537 let fake_model = model.as_fake();
538
539 // Send initial user message and verify it's cached
540 thread
541 .update(cx, |thread, cx| {
542 thread.send(UserMessageId::new(), ["Message 1"], cx)
543 })
544 .unwrap();
545 cx.run_until_parked();
546
547 let completion = fake_model.pending_completions().pop().unwrap();
548 assert_eq!(
549 completion.messages[1..],
550 vec![LanguageModelRequestMessage {
551 role: Role::User,
552 content: vec!["Message 1".into()],
553 cache: true,
554 reasoning_details: None,
555 }]
556 );
557 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
558 "Response to Message 1".into(),
559 ));
560 fake_model.end_last_completion_stream();
561 cx.run_until_parked();
562
563 // Send another user message and verify only the latest is cached
564 thread
565 .update(cx, |thread, cx| {
566 thread.send(UserMessageId::new(), ["Message 2"], cx)
567 })
568 .unwrap();
569 cx.run_until_parked();
570
571 let completion = fake_model.pending_completions().pop().unwrap();
572 assert_eq!(
573 completion.messages[1..],
574 vec![
575 LanguageModelRequestMessage {
576 role: Role::User,
577 content: vec!["Message 1".into()],
578 cache: false,
579 reasoning_details: None,
580 },
581 LanguageModelRequestMessage {
582 role: Role::Assistant,
583 content: vec!["Response to Message 1".into()],
584 cache: false,
585 reasoning_details: None,
586 },
587 LanguageModelRequestMessage {
588 role: Role::User,
589 content: vec!["Message 2".into()],
590 cache: true,
591 reasoning_details: None,
592 }
593 ]
594 );
595 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
596 "Response to Message 2".into(),
597 ));
598 fake_model.end_last_completion_stream();
599 cx.run_until_parked();
600
601 // Simulate a tool call and verify that the latest tool result is cached
602 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
603 thread
604 .update(cx, |thread, cx| {
605 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
606 })
607 .unwrap();
608 cx.run_until_parked();
609
610 let tool_use = LanguageModelToolUse {
611 id: "tool_1".into(),
612 name: EchoTool::NAME.into(),
613 raw_input: json!({"text": "test"}).to_string(),
614 input: json!({"text": "test"}),
615 is_input_complete: true,
616 thought_signature: None,
617 };
618 fake_model
619 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
620 fake_model.end_last_completion_stream();
621 cx.run_until_parked();
622
623 let completion = fake_model.pending_completions().pop().unwrap();
624 let tool_result = LanguageModelToolResult {
625 tool_use_id: "tool_1".into(),
626 tool_name: EchoTool::NAME.into(),
627 is_error: false,
628 content: "test".into(),
629 output: Some("test".into()),
630 };
631 assert_eq!(
632 completion.messages[1..],
633 vec![
634 LanguageModelRequestMessage {
635 role: Role::User,
636 content: vec!["Message 1".into()],
637 cache: false,
638 reasoning_details: None,
639 },
640 LanguageModelRequestMessage {
641 role: Role::Assistant,
642 content: vec!["Response to Message 1".into()],
643 cache: false,
644 reasoning_details: None,
645 },
646 LanguageModelRequestMessage {
647 role: Role::User,
648 content: vec!["Message 2".into()],
649 cache: false,
650 reasoning_details: None,
651 },
652 LanguageModelRequestMessage {
653 role: Role::Assistant,
654 content: vec!["Response to Message 2".into()],
655 cache: false,
656 reasoning_details: None,
657 },
658 LanguageModelRequestMessage {
659 role: Role::User,
660 content: vec!["Use the echo tool".into()],
661 cache: false,
662 reasoning_details: None,
663 },
664 LanguageModelRequestMessage {
665 role: Role::Assistant,
666 content: vec![MessageContent::ToolUse(tool_use)],
667 cache: false,
668 reasoning_details: None,
669 },
670 LanguageModelRequestMessage {
671 role: Role::User,
672 content: vec![MessageContent::ToolResult(tool_result)],
673 cache: true,
674 reasoning_details: None,
675 }
676 ]
677 );
678}
679
680#[gpui::test]
681#[cfg_attr(not(feature = "e2e"), ignore)]
682async fn test_basic_tool_calls(cx: &mut TestAppContext) {
683 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
684
685 // Test a tool call that's likely to complete *before* streaming stops.
686 let events = thread
687 .update(cx, |thread, cx| {
688 thread.add_tool(EchoTool);
689 thread.send(
690 UserMessageId::new(),
691 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
692 cx,
693 )
694 })
695 .unwrap()
696 .collect()
697 .await;
698 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
699
700 // Test a tool calls that's likely to complete *after* streaming stops.
701 let events = thread
702 .update(cx, |thread, cx| {
703 thread.remove_tool(&EchoTool::NAME);
704 thread.add_tool(DelayTool);
705 thread.send(
706 UserMessageId::new(),
707 [
708 "Now call the delay tool with 200ms.",
709 "When the timer goes off, then you echo the output of the tool.",
710 ],
711 cx,
712 )
713 })
714 .unwrap()
715 .collect()
716 .await;
717 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
718 thread.update(cx, |thread, _cx| {
719 assert!(
720 thread
721 .last_message()
722 .unwrap()
723 .as_agent_message()
724 .unwrap()
725 .content
726 .iter()
727 .any(|content| {
728 if let AgentMessageContent::Text(text) = content {
729 text.contains("Ding")
730 } else {
731 false
732 }
733 }),
734 "{}",
735 thread.to_markdown()
736 );
737 });
738}
739
740#[gpui::test]
741#[cfg_attr(not(feature = "e2e"), ignore)]
742async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
743 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
744
745 // Test a tool call that's likely to complete *before* streaming stops.
746 let mut events = thread
747 .update(cx, |thread, cx| {
748 thread.add_tool(WordListTool);
749 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
750 })
751 .unwrap();
752
753 let mut saw_partial_tool_use = false;
754 while let Some(event) = events.next().await {
755 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
756 thread.update(cx, |thread, _cx| {
757 // Look for a tool use in the thread's last message
758 let message = thread.last_message().unwrap();
759 let agent_message = message.as_agent_message().unwrap();
760 let last_content = agent_message.content.last().unwrap();
761 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
762 assert_eq!(last_tool_use.name.as_ref(), "word_list");
763 if tool_call.status == acp::ToolCallStatus::Pending {
764 if !last_tool_use.is_input_complete
765 && last_tool_use.input.get("g").is_none()
766 {
767 saw_partial_tool_use = true;
768 }
769 } else {
770 last_tool_use
771 .input
772 .get("a")
773 .expect("'a' has streamed because input is now complete");
774 last_tool_use
775 .input
776 .get("g")
777 .expect("'g' has streamed because input is now complete");
778 }
779 } else {
780 panic!("last content should be a tool use");
781 }
782 });
783 }
784 }
785
786 assert!(
787 saw_partial_tool_use,
788 "should see at least one partially streamed tool use in the history"
789 );
790}
791
792#[gpui::test]
793async fn test_tool_authorization(cx: &mut TestAppContext) {
794 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
795 let fake_model = model.as_fake();
796
797 let mut events = thread
798 .update(cx, |thread, cx| {
799 thread.add_tool(ToolRequiringPermission);
800 thread.send(UserMessageId::new(), ["abc"], cx)
801 })
802 .unwrap();
803 cx.run_until_parked();
804 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
805 LanguageModelToolUse {
806 id: "tool_id_1".into(),
807 name: ToolRequiringPermission::NAME.into(),
808 raw_input: "{}".into(),
809 input: json!({}),
810 is_input_complete: true,
811 thought_signature: None,
812 },
813 ));
814 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
815 LanguageModelToolUse {
816 id: "tool_id_2".into(),
817 name: ToolRequiringPermission::NAME.into(),
818 raw_input: "{}".into(),
819 input: json!({}),
820 is_input_complete: true,
821 thought_signature: None,
822 },
823 ));
824 fake_model.end_last_completion_stream();
825 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
826 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
827
828 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
829 tool_call_auth_1
830 .response
831 .send(acp::PermissionOptionId::new("allow"))
832 .unwrap();
833 cx.run_until_parked();
834
835 // Reject the second - send "deny" option_id directly since Deny is now a button
836 tool_call_auth_2
837 .response
838 .send(acp::PermissionOptionId::new("deny"))
839 .unwrap();
840 cx.run_until_parked();
841
842 let completion = fake_model.pending_completions().pop().unwrap();
843 let message = completion.messages.last().unwrap();
844 assert_eq!(
845 message.content,
846 vec![
847 language_model::MessageContent::ToolResult(LanguageModelToolResult {
848 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
849 tool_name: ToolRequiringPermission::NAME.into(),
850 is_error: false,
851 content: "Allowed".into(),
852 output: Some("Allowed".into())
853 }),
854 language_model::MessageContent::ToolResult(LanguageModelToolResult {
855 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
856 tool_name: ToolRequiringPermission::NAME.into(),
857 is_error: true,
858 content: "Permission to run tool denied by user".into(),
859 output: Some("Permission to run tool denied by user".into())
860 })
861 ]
862 );
863
864 // Simulate yet another tool call.
865 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
866 LanguageModelToolUse {
867 id: "tool_id_3".into(),
868 name: ToolRequiringPermission::NAME.into(),
869 raw_input: "{}".into(),
870 input: json!({}),
871 is_input_complete: true,
872 thought_signature: None,
873 },
874 ));
875 fake_model.end_last_completion_stream();
876
877 // Respond by always allowing tools - send transformed option_id
878 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
879 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
880 tool_call_auth_3
881 .response
882 .send(acp::PermissionOptionId::new(
883 "always_allow:tool_requiring_permission",
884 ))
885 .unwrap();
886 cx.run_until_parked();
887 let completion = fake_model.pending_completions().pop().unwrap();
888 let message = completion.messages.last().unwrap();
889 assert_eq!(
890 message.content,
891 vec![language_model::MessageContent::ToolResult(
892 LanguageModelToolResult {
893 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
894 tool_name: ToolRequiringPermission::NAME.into(),
895 is_error: false,
896 content: "Allowed".into(),
897 output: Some("Allowed".into())
898 }
899 )]
900 );
901
902 // Simulate a final tool call, ensuring we don't trigger authorization.
903 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
904 LanguageModelToolUse {
905 id: "tool_id_4".into(),
906 name: ToolRequiringPermission::NAME.into(),
907 raw_input: "{}".into(),
908 input: json!({}),
909 is_input_complete: true,
910 thought_signature: None,
911 },
912 ));
913 fake_model.end_last_completion_stream();
914 cx.run_until_parked();
915 let completion = fake_model.pending_completions().pop().unwrap();
916 let message = completion.messages.last().unwrap();
917 assert_eq!(
918 message.content,
919 vec![language_model::MessageContent::ToolResult(
920 LanguageModelToolResult {
921 tool_use_id: "tool_id_4".into(),
922 tool_name: ToolRequiringPermission::NAME.into(),
923 is_error: false,
924 content: "Allowed".into(),
925 output: Some("Allowed".into())
926 }
927 )]
928 );
929}
930
931#[gpui::test]
932async fn test_tool_hallucination(cx: &mut TestAppContext) {
933 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
934 let fake_model = model.as_fake();
935
936 let mut events = thread
937 .update(cx, |thread, cx| {
938 thread.send(UserMessageId::new(), ["abc"], cx)
939 })
940 .unwrap();
941 cx.run_until_parked();
942 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
943 LanguageModelToolUse {
944 id: "tool_id_1".into(),
945 name: "nonexistent_tool".into(),
946 raw_input: "{}".into(),
947 input: json!({}),
948 is_input_complete: true,
949 thought_signature: None,
950 },
951 ));
952 fake_model.end_last_completion_stream();
953
954 let tool_call = expect_tool_call(&mut events).await;
955 assert_eq!(tool_call.title, "nonexistent_tool");
956 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
957 let update = expect_tool_call_update_fields(&mut events).await;
958 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
959}
960
961async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
962 let event = events
963 .next()
964 .await
965 .expect("no tool call authorization event received")
966 .unwrap();
967 match event {
968 ThreadEvent::ToolCall(tool_call) => tool_call,
969 event => {
970 panic!("Unexpected event {event:?}");
971 }
972 }
973}
974
975async fn expect_tool_call_update_fields(
976 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
977) -> acp::ToolCallUpdate {
978 let event = events
979 .next()
980 .await
981 .expect("no tool call authorization event received")
982 .unwrap();
983 match event {
984 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
985 event => {
986 panic!("Unexpected event {event:?}");
987 }
988 }
989}
990
991async fn next_tool_call_authorization(
992 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
993) -> ToolCallAuthorization {
994 loop {
995 let event = events
996 .next()
997 .await
998 .expect("no tool call authorization event received")
999 .unwrap();
1000 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1001 let permission_kinds = tool_call_authorization
1002 .options
1003 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1004 .map(|option| option.kind);
1005 let allow_once = tool_call_authorization
1006 .options
1007 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1008 .map(|option| option.kind);
1009
1010 assert_eq!(
1011 permission_kinds,
1012 Some(acp::PermissionOptionKind::AllowAlways)
1013 );
1014 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1015 return tool_call_authorization;
1016 }
1017 }
1018}
1019
/// A multi-word terminal command yields three dropdown choices: the whole tool,
/// the `cargo build` command pattern, and a one-time allowance.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 3);

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1041
/// When the second token is a flag (`ls -la`), the command pattern label names
/// only the command itself (`ls`).
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 3);

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1061
/// A single-word terminal command (`whoami`) still yields three dropdown choices,
/// with the pattern label naming that command.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 3);

    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1081
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing `src/main.rs` should offer both a tool-wide choice and one
    // labelled with the `src/` directory.
    let choices =
        match ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()])
            .build_permission_options()
        {
            PermissionOptions::Dropdown(choices) => choices,
            _ => panic!("Expected dropdown permission options"),
        };

    let allow_labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    let offers = |label: &str| allow_labels.iter().any(|candidate| *candidate == label);

    assert!(offers("Always for edit file"));
    assert!(offers("Always for `src/`"));
}
1099
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a URL should offer both a tool-wide choice and one labelled
    // with the URL's domain.
    let choices = match ToolPermissionContext::new(
        FetchTool::NAME,
        vec!["https://docs.rs/gpui".to_string()],
    )
    .build_permission_options()
    {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let allow_labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    let offers = |label: &str| allow_labels.iter().any(|candidate| *candidate == label);

    assert!(offers("Always for fetch"));
    assert!(offers("Always for `docs.rs`"));
}
1117
#[test]
fn test_permission_options_without_pattern() {
    // For `./deploy.sh --production` only two choices are expected: tool-wide
    // and one-off, with no per-command ("... commands") entry.
    let choices = match ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    )
    .build_permission_options()
    {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 2);

    let allow_labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    let offers = |label: &str| allow_labels.iter().any(|candidate| *candidate == label);

    assert!(offers("Always for terminal"));
    assert!(offers("Only this time"));
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1139
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization must use the flat layout with exactly two
    // one-time options: allow once and reject once.
    let options = match ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    )
    .build_permission_options()
    {
        PermissionOptions::Flat(options) => options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };
    assert_eq!(options.len(), 2);

    // True when some option has both the given id and the given kind.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };

    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1160
#[test]
fn test_permission_option_ids_for_terminal() {
    // Checks the option ids on both the allow and deny side of every dropdown
    // choice built for a terminal command.
    let choices = match ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options()
    {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Split each choice into its allow id and deny id in a single pass.
    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    let has_id = |ids: &[String], wanted: &str| ids.iter().any(|id| id == wanted);
    let has_prefix = |ids: &[String], prefix: &str| ids.iter().any(|id| id.starts_with(prefix));

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    assert!(
        has_prefix(&allow_ids, "always_allow_pattern:terminal\n"),
        "Missing allow pattern option"
    );

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    assert!(
        has_prefix(&deny_ids, "always_deny_pattern:terminal\n"),
        "Missing deny pattern option"
    );
}
1200
1201#[gpui::test]
1202#[cfg_attr(not(feature = "e2e"), ignore)]
1203async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1204 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1205
1206 // Test concurrent tool calls with different delay times
1207 let events = thread
1208 .update(cx, |thread, cx| {
1209 thread.add_tool(DelayTool);
1210 thread.send(
1211 UserMessageId::new(),
1212 [
1213 "Call the delay tool twice in the same message.",
1214 "Once with 100ms. Once with 300ms.",
1215 "When both timers are complete, describe the outputs.",
1216 ],
1217 cx,
1218 )
1219 })
1220 .unwrap()
1221 .collect()
1222 .await;
1223
1224 let stop_reasons = stop_events(events);
1225 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1226
1227 thread.update(cx, |thread, _cx| {
1228 let last_message = thread.last_message().unwrap();
1229 let agent_message = last_message.as_agent_message().unwrap();
1230 let text = agent_message
1231 .content
1232 .iter()
1233 .filter_map(|content| {
1234 if let AgentMessageContent::Text(text) = content {
1235 Some(text.as_str())
1236 } else {
1237 None
1238 }
1239 })
1240 .collect::<String>();
1241
1242 assert!(text.contains("Ding"));
1243 });
1244}
1245
1246#[gpui::test]
1247async fn test_profiles(cx: &mut TestAppContext) {
1248 let ThreadTest {
1249 model, thread, fs, ..
1250 } = setup(cx, TestModel::Fake).await;
1251 let fake_model = model.as_fake();
1252
1253 thread.update(cx, |thread, _cx| {
1254 thread.add_tool(DelayTool);
1255 thread.add_tool(EchoTool);
1256 thread.add_tool(InfiniteTool);
1257 });
1258
1259 // Override profiles and wait for settings to be loaded.
1260 fs.insert_file(
1261 paths::settings_file(),
1262 json!({
1263 "agent": {
1264 "profiles": {
1265 "test-1": {
1266 "name": "Test Profile 1",
1267 "tools": {
1268 EchoTool::NAME: true,
1269 DelayTool::NAME: true,
1270 }
1271 },
1272 "test-2": {
1273 "name": "Test Profile 2",
1274 "tools": {
1275 InfiniteTool::NAME: true,
1276 }
1277 }
1278 }
1279 }
1280 })
1281 .to_string()
1282 .into_bytes(),
1283 )
1284 .await;
1285 cx.run_until_parked();
1286
1287 // Test that test-1 profile (default) has echo and delay tools
1288 thread
1289 .update(cx, |thread, cx| {
1290 thread.set_profile(AgentProfileId("test-1".into()), cx);
1291 thread.send(UserMessageId::new(), ["test"], cx)
1292 })
1293 .unwrap();
1294 cx.run_until_parked();
1295
1296 let mut pending_completions = fake_model.pending_completions();
1297 assert_eq!(pending_completions.len(), 1);
1298 let completion = pending_completions.pop().unwrap();
1299 let tool_names: Vec<String> = completion
1300 .tools
1301 .iter()
1302 .map(|tool| tool.name.clone())
1303 .collect();
1304 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1305 fake_model.end_last_completion_stream();
1306
1307 // Switch to test-2 profile, and verify that it has only the infinite tool.
1308 thread
1309 .update(cx, |thread, cx| {
1310 thread.set_profile(AgentProfileId("test-2".into()), cx);
1311 thread.send(UserMessageId::new(), ["test2"], cx)
1312 })
1313 .unwrap();
1314 cx.run_until_parked();
1315 let mut pending_completions = fake_model.pending_completions();
1316 assert_eq!(pending_completions.len(), 1);
1317 let completion = pending_completions.pop().unwrap();
1318 let tool_names: Vec<String> = completion
1319 .tools
1320 .iter()
1321 .map(|tool| tool.name.clone())
1322 .collect();
1323 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1324}
1325
1326#[gpui::test]
1327async fn test_mcp_tools(cx: &mut TestAppContext) {
1328 let ThreadTest {
1329 model,
1330 thread,
1331 context_server_store,
1332 fs,
1333 ..
1334 } = setup(cx, TestModel::Fake).await;
1335 let fake_model = model.as_fake();
1336
1337 // Override profiles and wait for settings to be loaded.
1338 fs.insert_file(
1339 paths::settings_file(),
1340 json!({
1341 "agent": {
1342 "tool_permissions": { "default": "allow" },
1343 "profiles": {
1344 "test": {
1345 "name": "Test Profile",
1346 "enable_all_context_servers": true,
1347 "tools": {
1348 EchoTool::NAME: true,
1349 }
1350 },
1351 }
1352 }
1353 })
1354 .to_string()
1355 .into_bytes(),
1356 )
1357 .await;
1358 cx.run_until_parked();
1359 thread.update(cx, |thread, cx| {
1360 thread.set_profile(AgentProfileId("test".into()), cx)
1361 });
1362
1363 let mut mcp_tool_calls = setup_context_server(
1364 "test_server",
1365 vec![context_server::types::Tool {
1366 name: "echo".into(),
1367 description: None,
1368 input_schema: serde_json::to_value(EchoTool::input_schema(
1369 LanguageModelToolSchemaFormat::JsonSchema,
1370 ))
1371 .unwrap(),
1372 output_schema: None,
1373 annotations: None,
1374 }],
1375 &context_server_store,
1376 cx,
1377 );
1378
1379 let events = thread.update(cx, |thread, cx| {
1380 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1381 });
1382 cx.run_until_parked();
1383
1384 // Simulate the model calling the MCP tool.
1385 let completion = fake_model.pending_completions().pop().unwrap();
1386 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1387 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1388 LanguageModelToolUse {
1389 id: "tool_1".into(),
1390 name: "echo".into(),
1391 raw_input: json!({"text": "test"}).to_string(),
1392 input: json!({"text": "test"}),
1393 is_input_complete: true,
1394 thought_signature: None,
1395 },
1396 ));
1397 fake_model.end_last_completion_stream();
1398 cx.run_until_parked();
1399
1400 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1401 assert_eq!(tool_call_params.name, "echo");
1402 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1403 tool_call_response
1404 .send(context_server::types::CallToolResponse {
1405 content: vec![context_server::types::ToolResponseContent::Text {
1406 text: "test".into(),
1407 }],
1408 is_error: None,
1409 meta: None,
1410 structured_content: None,
1411 })
1412 .unwrap();
1413 cx.run_until_parked();
1414
1415 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1416 fake_model.send_last_completion_stream_text_chunk("Done!");
1417 fake_model.end_last_completion_stream();
1418 events.collect::<Vec<_>>().await;
1419
1420 // Send again after adding the echo tool, ensuring the name collision is resolved.
1421 let events = thread.update(cx, |thread, cx| {
1422 thread.add_tool(EchoTool);
1423 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1424 });
1425 cx.run_until_parked();
1426 let completion = fake_model.pending_completions().pop().unwrap();
1427 assert_eq!(
1428 tool_names_for_completion(&completion),
1429 vec!["echo", "test_server_echo"]
1430 );
1431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1432 LanguageModelToolUse {
1433 id: "tool_2".into(),
1434 name: "test_server_echo".into(),
1435 raw_input: json!({"text": "mcp"}).to_string(),
1436 input: json!({"text": "mcp"}),
1437 is_input_complete: true,
1438 thought_signature: None,
1439 },
1440 ));
1441 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1442 LanguageModelToolUse {
1443 id: "tool_3".into(),
1444 name: "echo".into(),
1445 raw_input: json!({"text": "native"}).to_string(),
1446 input: json!({"text": "native"}),
1447 is_input_complete: true,
1448 thought_signature: None,
1449 },
1450 ));
1451 fake_model.end_last_completion_stream();
1452 cx.run_until_parked();
1453
1454 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1455 assert_eq!(tool_call_params.name, "echo");
1456 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1457 tool_call_response
1458 .send(context_server::types::CallToolResponse {
1459 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1460 is_error: None,
1461 meta: None,
1462 structured_content: None,
1463 })
1464 .unwrap();
1465 cx.run_until_parked();
1466
1467 // Ensure the tool results were inserted with the correct names.
1468 let completion = fake_model.pending_completions().pop().unwrap();
1469 assert_eq!(
1470 completion.messages.last().unwrap().content,
1471 vec![
1472 MessageContent::ToolResult(LanguageModelToolResult {
1473 tool_use_id: "tool_3".into(),
1474 tool_name: "echo".into(),
1475 is_error: false,
1476 content: "native".into(),
1477 output: Some("native".into()),
1478 },),
1479 MessageContent::ToolResult(LanguageModelToolResult {
1480 tool_use_id: "tool_2".into(),
1481 tool_name: "test_server_echo".into(),
1482 is_error: false,
1483 content: "mcp".into(),
1484 output: Some("mcp".into()),
1485 },),
1486 ]
1487 );
1488 fake_model.end_last_completion_stream();
1489 events.collect::<Vec<_>>().await;
1490}
1491
1492#[gpui::test]
1493async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1494 let ThreadTest {
1495 model,
1496 thread,
1497 context_server_store,
1498 fs,
1499 ..
1500 } = setup(cx, TestModel::Fake).await;
1501 let fake_model = model.as_fake();
1502
1503 // Setup settings to allow MCP tools
1504 fs.insert_file(
1505 paths::settings_file(),
1506 json!({
1507 "agent": {
1508 "always_allow_tool_actions": true,
1509 "profiles": {
1510 "test": {
1511 "name": "Test Profile",
1512 "enable_all_context_servers": true,
1513 "tools": {}
1514 },
1515 }
1516 }
1517 })
1518 .to_string()
1519 .into_bytes(),
1520 )
1521 .await;
1522 cx.run_until_parked();
1523 thread.update(cx, |thread, cx| {
1524 thread.set_profile(AgentProfileId("test".into()), cx)
1525 });
1526
1527 // Setup a context server with a tool
1528 let mut mcp_tool_calls = setup_context_server(
1529 "github_server",
1530 vec![context_server::types::Tool {
1531 name: "issue_read".into(),
1532 description: Some("Read a GitHub issue".into()),
1533 input_schema: json!({
1534 "type": "object",
1535 "properties": {
1536 "issue_url": { "type": "string" }
1537 }
1538 }),
1539 output_schema: None,
1540 annotations: None,
1541 }],
1542 &context_server_store,
1543 cx,
1544 );
1545
1546 // Send a message and have the model call the MCP tool
1547 let events = thread.update(cx, |thread, cx| {
1548 thread
1549 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1550 .unwrap()
1551 });
1552 cx.run_until_parked();
1553
1554 // Verify the MCP tool is available to the model
1555 let completion = fake_model.pending_completions().pop().unwrap();
1556 assert_eq!(
1557 tool_names_for_completion(&completion),
1558 vec!["issue_read"],
1559 "MCP tool should be available"
1560 );
1561
1562 // Simulate the model calling the MCP tool
1563 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1564 LanguageModelToolUse {
1565 id: "tool_1".into(),
1566 name: "issue_read".into(),
1567 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1568 .to_string(),
1569 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1570 is_input_complete: true,
1571 thought_signature: None,
1572 },
1573 ));
1574 fake_model.end_last_completion_stream();
1575 cx.run_until_parked();
1576
1577 // The MCP server receives the tool call and responds with content
1578 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1579 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1580 assert_eq!(tool_call_params.name, "issue_read");
1581 tool_call_response
1582 .send(context_server::types::CallToolResponse {
1583 content: vec![context_server::types::ToolResponseContent::Text {
1584 text: expected_tool_output.into(),
1585 }],
1586 is_error: None,
1587 meta: None,
1588 structured_content: None,
1589 })
1590 .unwrap();
1591 cx.run_until_parked();
1592
1593 // After tool completes, the model continues with a new completion request
1594 // that includes the tool results. We need to respond to this.
1595 let _completion = fake_model.pending_completions().pop().unwrap();
1596 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1597 fake_model
1598 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1599 fake_model.end_last_completion_stream();
1600 events.collect::<Vec<_>>().await;
1601
1602 // Verify the tool result is stored in the thread by checking the markdown output.
1603 // The tool result is in the first assistant message (not the last one, which is
1604 // the model's response after the tool completed).
1605 thread.update(cx, |thread, _cx| {
1606 let markdown = thread.to_markdown();
1607 assert!(
1608 markdown.contains("**Tool Result**: issue_read"),
1609 "Thread should contain tool result header"
1610 );
1611 assert!(
1612 markdown.contains(expected_tool_output),
1613 "Thread should contain tool output: {}",
1614 expected_tool_output
1615 );
1616 });
1617
1618 // Simulate app restart: disconnect the MCP server.
1619 // After restart, the MCP server won't be connected yet when the thread is replayed.
1620 context_server_store.update(cx, |store, cx| {
1621 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1622 });
1623 cx.run_until_parked();
1624
1625 // Replay the thread (this is what happens when loading a saved thread)
1626 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1627
1628 let mut found_tool_call = None;
1629 let mut found_tool_call_update_with_output = None;
1630
1631 while let Some(event) = replay_events.next().await {
1632 let event = event.unwrap();
1633 match &event {
1634 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1635 found_tool_call = Some(tc.clone());
1636 }
1637 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1638 if update.tool_call_id.to_string() == "tool_1" =>
1639 {
1640 if update.fields.raw_output.is_some() {
1641 found_tool_call_update_with_output = Some(update.clone());
1642 }
1643 }
1644 _ => {}
1645 }
1646 }
1647
1648 // The tool call should be found
1649 assert!(
1650 found_tool_call.is_some(),
1651 "Tool call should be emitted during replay"
1652 );
1653
1654 assert!(
1655 found_tool_call_update_with_output.is_some(),
1656 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1657 );
1658
1659 let update = found_tool_call_update_with_output.unwrap();
1660 assert_eq!(
1661 update.fields.raw_output,
1662 Some(expected_tool_output.into()),
1663 "raw_output should contain the saved tool result"
1664 );
1665
1666 // Also verify the status is correct (completed, not failed)
1667 assert_eq!(
1668 update.fields.status,
1669 Some(acp::ToolCallStatus::Completed),
1670 "Tool call status should reflect the original completion status"
1671 );
1672}
1673
1674#[gpui::test]
1675async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1676 let ThreadTest {
1677 model,
1678 thread,
1679 context_server_store,
1680 fs,
1681 ..
1682 } = setup(cx, TestModel::Fake).await;
1683 let fake_model = model.as_fake();
1684
1685 // Set up a profile with all tools enabled
1686 fs.insert_file(
1687 paths::settings_file(),
1688 json!({
1689 "agent": {
1690 "profiles": {
1691 "test": {
1692 "name": "Test Profile",
1693 "enable_all_context_servers": true,
1694 "tools": {
1695 EchoTool::NAME: true,
1696 DelayTool::NAME: true,
1697 WordListTool::NAME: true,
1698 ToolRequiringPermission::NAME: true,
1699 InfiniteTool::NAME: true,
1700 }
1701 },
1702 }
1703 }
1704 })
1705 .to_string()
1706 .into_bytes(),
1707 )
1708 .await;
1709 cx.run_until_parked();
1710
1711 thread.update(cx, |thread, cx| {
1712 thread.set_profile(AgentProfileId("test".into()), cx);
1713 thread.add_tool(EchoTool);
1714 thread.add_tool(DelayTool);
1715 thread.add_tool(WordListTool);
1716 thread.add_tool(ToolRequiringPermission);
1717 thread.add_tool(InfiniteTool);
1718 });
1719
1720 // Set up multiple context servers with some overlapping tool names
1721 let _server1_calls = setup_context_server(
1722 "xxx",
1723 vec![
1724 context_server::types::Tool {
1725 name: "echo".into(), // Conflicts with native EchoTool
1726 description: None,
1727 input_schema: serde_json::to_value(EchoTool::input_schema(
1728 LanguageModelToolSchemaFormat::JsonSchema,
1729 ))
1730 .unwrap(),
1731 output_schema: None,
1732 annotations: None,
1733 },
1734 context_server::types::Tool {
1735 name: "unique_tool_1".into(),
1736 description: None,
1737 input_schema: json!({"type": "object", "properties": {}}),
1738 output_schema: None,
1739 annotations: None,
1740 },
1741 ],
1742 &context_server_store,
1743 cx,
1744 );
1745
1746 let _server2_calls = setup_context_server(
1747 "yyy",
1748 vec![
1749 context_server::types::Tool {
1750 name: "echo".into(), // Also conflicts with native EchoTool
1751 description: None,
1752 input_schema: serde_json::to_value(EchoTool::input_schema(
1753 LanguageModelToolSchemaFormat::JsonSchema,
1754 ))
1755 .unwrap(),
1756 output_schema: None,
1757 annotations: None,
1758 },
1759 context_server::types::Tool {
1760 name: "unique_tool_2".into(),
1761 description: None,
1762 input_schema: json!({"type": "object", "properties": {}}),
1763 output_schema: None,
1764 annotations: None,
1765 },
1766 context_server::types::Tool {
1767 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1768 description: None,
1769 input_schema: json!({"type": "object", "properties": {}}),
1770 output_schema: None,
1771 annotations: None,
1772 },
1773 context_server::types::Tool {
1774 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1775 description: None,
1776 input_schema: json!({"type": "object", "properties": {}}),
1777 output_schema: None,
1778 annotations: None,
1779 },
1780 ],
1781 &context_server_store,
1782 cx,
1783 );
1784 let _server3_calls = setup_context_server(
1785 "zzz",
1786 vec![
1787 context_server::types::Tool {
1788 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1789 description: None,
1790 input_schema: json!({"type": "object", "properties": {}}),
1791 output_schema: None,
1792 annotations: None,
1793 },
1794 context_server::types::Tool {
1795 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1796 description: None,
1797 input_schema: json!({"type": "object", "properties": {}}),
1798 output_schema: None,
1799 annotations: None,
1800 },
1801 context_server::types::Tool {
1802 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1803 description: None,
1804 input_schema: json!({"type": "object", "properties": {}}),
1805 output_schema: None,
1806 annotations: None,
1807 },
1808 ],
1809 &context_server_store,
1810 cx,
1811 );
1812
1813 // Server with spaces in name - tests snake_case conversion for API compatibility
1814 let _server4_calls = setup_context_server(
1815 "Azure DevOps",
1816 vec![context_server::types::Tool {
1817 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1818 description: None,
1819 input_schema: serde_json::to_value(EchoTool::input_schema(
1820 LanguageModelToolSchemaFormat::JsonSchema,
1821 ))
1822 .unwrap(),
1823 output_schema: None,
1824 annotations: None,
1825 }],
1826 &context_server_store,
1827 cx,
1828 );
1829
1830 thread
1831 .update(cx, |thread, cx| {
1832 thread.send(UserMessageId::new(), ["Go"], cx)
1833 })
1834 .unwrap();
1835 cx.run_until_parked();
1836 let completion = fake_model.pending_completions().pop().unwrap();
1837 assert_eq!(
1838 tool_names_for_completion(&completion),
1839 vec![
1840 "azure_dev_ops_echo",
1841 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1842 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1843 "delay",
1844 "echo",
1845 "infinite",
1846 "tool_requiring_permission",
1847 "unique_tool_1",
1848 "unique_tool_2",
1849 "word_list",
1850 "xxx_echo",
1851 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1852 "yyy_echo",
1853 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1854 ]
1855 );
1856}
1857
1858#[gpui::test]
1859#[cfg_attr(not(feature = "e2e"), ignore)]
1860async fn test_cancellation(cx: &mut TestAppContext) {
1861 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1862
1863 let mut events = thread
1864 .update(cx, |thread, cx| {
1865 thread.add_tool(InfiniteTool);
1866 thread.add_tool(EchoTool);
1867 thread.send(
1868 UserMessageId::new(),
1869 ["Call the echo tool, then call the infinite tool, then explain their output"],
1870 cx,
1871 )
1872 })
1873 .unwrap();
1874
1875 // Wait until both tools are called.
1876 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1877 let mut echo_id = None;
1878 let mut echo_completed = false;
1879 while let Some(event) = events.next().await {
1880 match event.unwrap() {
1881 ThreadEvent::ToolCall(tool_call) => {
1882 assert_eq!(tool_call.title, expected_tools.remove(0));
1883 if tool_call.title == "Echo" {
1884 echo_id = Some(tool_call.tool_call_id);
1885 }
1886 }
1887 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1888 acp::ToolCallUpdate {
1889 tool_call_id,
1890 fields:
1891 acp::ToolCallUpdateFields {
1892 status: Some(acp::ToolCallStatus::Completed),
1893 ..
1894 },
1895 ..
1896 },
1897 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1898 echo_completed = true;
1899 }
1900 _ => {}
1901 }
1902
1903 if expected_tools.is_empty() && echo_completed {
1904 break;
1905 }
1906 }
1907
1908 // Cancel the current send and ensure that the event stream is closed, even
1909 // if one of the tools is still running.
1910 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1911 let events = events.collect::<Vec<_>>().await;
1912 let last_event = events.last();
1913 assert!(
1914 matches!(
1915 last_event,
1916 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1917 ),
1918 "unexpected event {last_event:?}"
1919 );
1920
1921 // Ensure we can still send a new message after cancellation.
1922 let events = thread
1923 .update(cx, |thread, cx| {
1924 thread.send(
1925 UserMessageId::new(),
1926 ["Testing: reply with 'Hello' then stop."],
1927 cx,
1928 )
1929 })
1930 .unwrap()
1931 .collect::<Vec<_>>()
1932 .await;
1933 thread.update(cx, |thread, _cx| {
1934 let message = thread.last_message().unwrap();
1935 let agent_message = message.as_agent_message().unwrap();
1936 assert_eq!(
1937 agent_message.content,
1938 vec![AgentMessageContent::Text("Hello".to_string())]
1939 );
1940 });
1941 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1942}
1943
1944#[gpui::test]
1945async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1946 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1947 always_allow_tools(cx);
1948 let fake_model = model.as_fake();
1949
1950 let environment = Rc::new(cx.update(|cx| {
1951 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1952 }));
1953 let handle = environment.terminal_handle.clone().unwrap();
1954
1955 let mut events = thread
1956 .update(cx, |thread, cx| {
1957 thread.add_tool(crate::TerminalTool::new(
1958 thread.project().clone(),
1959 environment,
1960 ));
1961 thread.send(UserMessageId::new(), ["run a command"], cx)
1962 })
1963 .unwrap();
1964
1965 cx.run_until_parked();
1966
1967 // Simulate the model calling the terminal tool
1968 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1969 LanguageModelToolUse {
1970 id: "terminal_tool_1".into(),
1971 name: TerminalTool::NAME.into(),
1972 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1973 input: json!({"command": "sleep 1000", "cd": "."}),
1974 is_input_complete: true,
1975 thought_signature: None,
1976 },
1977 ));
1978 fake_model.end_last_completion_stream();
1979
1980 // Wait for the terminal tool to start running
1981 wait_for_terminal_tool_started(&mut events, cx).await;
1982
1983 // Cancel the thread while the terminal is running
1984 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1985
1986 // Collect remaining events, driving the executor to let cancellation complete
1987 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1988
1989 // Verify the terminal was killed
1990 assert!(
1991 handle.was_killed(),
1992 "expected terminal handle to be killed on cancellation"
1993 );
1994
1995 // Verify we got a cancellation stop event
1996 assert_eq!(
1997 stop_events(remaining_events),
1998 vec![acp::StopReason::Cancelled],
1999 );
2000
2001 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
2002 thread.update(cx, |thread, _cx| {
2003 let message = thread.last_message().unwrap();
2004 let agent_message = message.as_agent_message().unwrap();
2005
2006 let tool_use = agent_message
2007 .content
2008 .iter()
2009 .find_map(|content| match content {
2010 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2011 _ => None,
2012 })
2013 .expect("expected tool use in agent message");
2014
2015 let tool_result = agent_message
2016 .tool_results
2017 .get(&tool_use.id)
2018 .expect("expected tool result");
2019
2020 let result_text = match &tool_result.content {
2021 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2022 _ => panic!("expected text content in tool result"),
2023 };
2024
2025 // "partial output" comes from FakeTerminalHandle's output field
2026 assert!(
2027 result_text.contains("partial output"),
2028 "expected tool result to contain terminal output, got: {result_text}"
2029 );
2030 // Match the actual format from process_content in terminal_tool.rs
2031 assert!(
2032 result_text.contains("The user stopped this command"),
2033 "expected tool result to indicate user stopped, got: {result_text}"
2034 );
2035 });
2036
2037 // Verify we can send a new message after cancellation
2038 verify_thread_recovery(&thread, &fake_model, cx).await;
2039}
2040
2041#[gpui::test]
2042async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2043 // This test verifies that tools which properly handle cancellation via
2044 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2045 // to cancellation and report that they were cancelled.
2046 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2047 always_allow_tools(cx);
2048 let fake_model = model.as_fake();
2049
2050 let (tool, was_cancelled) = CancellationAwareTool::new();
2051
2052 let mut events = thread
2053 .update(cx, |thread, cx| {
2054 thread.add_tool(tool);
2055 thread.send(
2056 UserMessageId::new(),
2057 ["call the cancellation aware tool"],
2058 cx,
2059 )
2060 })
2061 .unwrap();
2062
2063 cx.run_until_parked();
2064
2065 // Simulate the model calling the cancellation-aware tool
2066 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2067 LanguageModelToolUse {
2068 id: "cancellation_aware_1".into(),
2069 name: "cancellation_aware".into(),
2070 raw_input: r#"{}"#.into(),
2071 input: json!({}),
2072 is_input_complete: true,
2073 thought_signature: None,
2074 },
2075 ));
2076 fake_model.end_last_completion_stream();
2077
2078 cx.run_until_parked();
2079
2080 // Wait for the tool call to be reported
2081 let mut tool_started = false;
2082 let deadline = cx.executor().num_cpus() * 100;
2083 for _ in 0..deadline {
2084 cx.run_until_parked();
2085
2086 while let Some(Some(event)) = events.next().now_or_never() {
2087 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2088 if tool_call.title == "Cancellation Aware Tool" {
2089 tool_started = true;
2090 break;
2091 }
2092 }
2093 }
2094
2095 if tool_started {
2096 break;
2097 }
2098
2099 cx.background_executor
2100 .timer(Duration::from_millis(10))
2101 .await;
2102 }
2103 assert!(tool_started, "expected cancellation aware tool to start");
2104
2105 // Cancel the thread and wait for it to complete
2106 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2107
2108 // The cancel task should complete promptly because the tool handles cancellation
2109 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2110 futures::select! {
2111 _ = cancel_task.fuse() => {}
2112 _ = timeout.fuse() => {
2113 panic!("cancel task timed out - tool did not respond to cancellation");
2114 }
2115 }
2116
2117 // Verify the tool detected cancellation via its flag
2118 assert!(
2119 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2120 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2121 );
2122
2123 // Collect remaining events
2124 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2125
2126 // Verify we got a cancellation stop event
2127 assert_eq!(
2128 stop_events(remaining_events),
2129 vec![acp::StopReason::Cancelled],
2130 );
2131
2132 // Verify we can send a new message after cancellation
2133 verify_thread_recovery(&thread, &fake_model, cx).await;
2134}
2135
2136/// Helper to verify thread can recover after cancellation by sending a simple message.
2137async fn verify_thread_recovery(
2138 thread: &Entity<Thread>,
2139 fake_model: &FakeLanguageModel,
2140 cx: &mut TestAppContext,
2141) {
2142 let events = thread
2143 .update(cx, |thread, cx| {
2144 thread.send(
2145 UserMessageId::new(),
2146 ["Testing: reply with 'Hello' then stop."],
2147 cx,
2148 )
2149 })
2150 .unwrap();
2151 cx.run_until_parked();
2152 fake_model.send_last_completion_stream_text_chunk("Hello");
2153 fake_model
2154 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2155 fake_model.end_last_completion_stream();
2156
2157 let events = events.collect::<Vec<_>>().await;
2158 thread.update(cx, |thread, _cx| {
2159 let message = thread.last_message().unwrap();
2160 let agent_message = message.as_agent_message().unwrap();
2161 assert_eq!(
2162 agent_message.content,
2163 vec![AgentMessageContent::Text("Hello".to_string())]
2164 );
2165 });
2166 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2167}
2168
2169/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2170async fn wait_for_terminal_tool_started(
2171 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2172 cx: &mut TestAppContext,
2173) {
2174 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2175 for _ in 0..deadline {
2176 cx.run_until_parked();
2177
2178 while let Some(Some(event)) = events.next().now_or_never() {
2179 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2180 update,
2181 ))) = &event
2182 {
2183 if update.fields.content.as_ref().is_some_and(|content| {
2184 content
2185 .iter()
2186 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2187 }) {
2188 return;
2189 }
2190 }
2191 }
2192
2193 cx.background_executor
2194 .timer(Duration::from_millis(10))
2195 .await;
2196 }
2197 panic!("terminal tool did not start within the expected time");
2198}
2199
2200/// Collects events until a Stop event is received, driving the executor to completion.
2201async fn collect_events_until_stop(
2202 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2203 cx: &mut TestAppContext,
2204) -> Vec<Result<ThreadEvent>> {
2205 let mut collected = Vec::new();
2206 let deadline = cx.executor().num_cpus() * 200;
2207
2208 for _ in 0..deadline {
2209 cx.executor().advance_clock(Duration::from_millis(10));
2210 cx.run_until_parked();
2211
2212 while let Some(Some(event)) = events.next().now_or_never() {
2213 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2214 collected.push(event);
2215 if is_stop {
2216 return collected;
2217 }
2218 }
2219 }
2220 panic!(
2221 "did not receive Stop event within the expected time; collected {} events",
2222 collected.len()
2223 );
2224}
2225
2226#[gpui::test]
2227async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2228 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2229 always_allow_tools(cx);
2230 let fake_model = model.as_fake();
2231
2232 let environment = Rc::new(cx.update(|cx| {
2233 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2234 }));
2235 let handle = environment.terminal_handle.clone().unwrap();
2236
2237 let message_id = UserMessageId::new();
2238 let mut events = thread
2239 .update(cx, |thread, cx| {
2240 thread.add_tool(crate::TerminalTool::new(
2241 thread.project().clone(),
2242 environment,
2243 ));
2244 thread.send(message_id.clone(), ["run a command"], cx)
2245 })
2246 .unwrap();
2247
2248 cx.run_until_parked();
2249
2250 // Simulate the model calling the terminal tool
2251 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2252 LanguageModelToolUse {
2253 id: "terminal_tool_1".into(),
2254 name: TerminalTool::NAME.into(),
2255 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2256 input: json!({"command": "sleep 1000", "cd": "."}),
2257 is_input_complete: true,
2258 thought_signature: None,
2259 },
2260 ));
2261 fake_model.end_last_completion_stream();
2262
2263 // Wait for the terminal tool to start running
2264 wait_for_terminal_tool_started(&mut events, cx).await;
2265
2266 // Truncate the thread while the terminal is running
2267 thread
2268 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2269 .unwrap();
2270
2271 // Drive the executor to let cancellation complete
2272 let _ = collect_events_until_stop(&mut events, cx).await;
2273
2274 // Verify the terminal was killed
2275 assert!(
2276 handle.was_killed(),
2277 "expected terminal handle to be killed on truncate"
2278 );
2279
2280 // Verify the thread is empty after truncation
2281 thread.update(cx, |thread, _cx| {
2282 assert_eq!(
2283 thread.to_markdown(),
2284 "",
2285 "expected thread to be empty after truncating the only message"
2286 );
2287 });
2288
2289 // Verify we can send a new message after truncation
2290 verify_thread_recovery(&thread, &fake_model, cx).await;
2291}
2292
2293#[gpui::test]
2294async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2295 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2296 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2297 always_allow_tools(cx);
2298 let fake_model = model.as_fake();
2299
2300 let environment = Rc::new(MultiTerminalEnvironment::new());
2301
2302 let mut events = thread
2303 .update(cx, |thread, cx| {
2304 thread.add_tool(crate::TerminalTool::new(
2305 thread.project().clone(),
2306 environment.clone(),
2307 ));
2308 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2309 })
2310 .unwrap();
2311
2312 cx.run_until_parked();
2313
2314 // Simulate the model calling two terminal tools
2315 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2316 LanguageModelToolUse {
2317 id: "terminal_tool_1".into(),
2318 name: TerminalTool::NAME.into(),
2319 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2320 input: json!({"command": "sleep 1000", "cd": "."}),
2321 is_input_complete: true,
2322 thought_signature: None,
2323 },
2324 ));
2325 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2326 LanguageModelToolUse {
2327 id: "terminal_tool_2".into(),
2328 name: TerminalTool::NAME.into(),
2329 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2330 input: json!({"command": "sleep 2000", "cd": "."}),
2331 is_input_complete: true,
2332 thought_signature: None,
2333 },
2334 ));
2335 fake_model.end_last_completion_stream();
2336
2337 // Wait for both terminal tools to start by counting terminal content updates
2338 let mut terminals_started = 0;
2339 let deadline = cx.executor().num_cpus() * 100;
2340 for _ in 0..deadline {
2341 cx.run_until_parked();
2342
2343 while let Some(Some(event)) = events.next().now_or_never() {
2344 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2345 update,
2346 ))) = &event
2347 {
2348 if update.fields.content.as_ref().is_some_and(|content| {
2349 content
2350 .iter()
2351 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2352 }) {
2353 terminals_started += 1;
2354 if terminals_started >= 2 {
2355 break;
2356 }
2357 }
2358 }
2359 }
2360 if terminals_started >= 2 {
2361 break;
2362 }
2363
2364 cx.background_executor
2365 .timer(Duration::from_millis(10))
2366 .await;
2367 }
2368 assert!(
2369 terminals_started >= 2,
2370 "expected 2 terminal tools to start, got {terminals_started}"
2371 );
2372
2373 // Cancel the thread while both terminals are running
2374 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2375
2376 // Collect remaining events
2377 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2378
2379 // Verify both terminal handles were killed
2380 let handles = environment.handles();
2381 assert_eq!(
2382 handles.len(),
2383 2,
2384 "expected 2 terminal handles to be created"
2385 );
2386 assert!(
2387 handles[0].was_killed(),
2388 "expected first terminal handle to be killed on cancellation"
2389 );
2390 assert!(
2391 handles[1].was_killed(),
2392 "expected second terminal handle to be killed on cancellation"
2393 );
2394
2395 // Verify we got a cancellation stop event
2396 assert_eq!(
2397 stop_events(remaining_events),
2398 vec![acp::StopReason::Cancelled],
2399 );
2400}
2401
2402#[gpui::test]
2403async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2404 // Tests that clicking the stop button on the terminal card (as opposed to the main
2405 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2406 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2407 always_allow_tools(cx);
2408 let fake_model = model.as_fake();
2409
2410 let environment = Rc::new(cx.update(|cx| {
2411 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2412 }));
2413 let handle = environment.terminal_handle.clone().unwrap();
2414
2415 let mut events = thread
2416 .update(cx, |thread, cx| {
2417 thread.add_tool(crate::TerminalTool::new(
2418 thread.project().clone(),
2419 environment,
2420 ));
2421 thread.send(UserMessageId::new(), ["run a command"], cx)
2422 })
2423 .unwrap();
2424
2425 cx.run_until_parked();
2426
2427 // Simulate the model calling the terminal tool
2428 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2429 LanguageModelToolUse {
2430 id: "terminal_tool_1".into(),
2431 name: TerminalTool::NAME.into(),
2432 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2433 input: json!({"command": "sleep 1000", "cd": "."}),
2434 is_input_complete: true,
2435 thought_signature: None,
2436 },
2437 ));
2438 fake_model.end_last_completion_stream();
2439
2440 // Wait for the terminal tool to start running
2441 wait_for_terminal_tool_started(&mut events, cx).await;
2442
2443 // Simulate user clicking stop on the terminal card itself.
2444 // This sets the flag and signals exit (simulating what the real UI would do).
2445 handle.set_stopped_by_user(true);
2446 handle.killed.store(true, Ordering::SeqCst);
2447 handle.signal_exit();
2448
2449 // Wait for the tool to complete
2450 cx.run_until_parked();
2451
2452 // The thread continues after tool completion - simulate the model ending its turn
2453 fake_model
2454 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2455 fake_model.end_last_completion_stream();
2456
2457 // Collect remaining events
2458 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2459
2460 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2461 assert_eq!(
2462 stop_events(remaining_events),
2463 vec![acp::StopReason::EndTurn],
2464 );
2465
2466 // Verify the tool result indicates user stopped
2467 thread.update(cx, |thread, _cx| {
2468 let message = thread.last_message().unwrap();
2469 let agent_message = message.as_agent_message().unwrap();
2470
2471 let tool_use = agent_message
2472 .content
2473 .iter()
2474 .find_map(|content| match content {
2475 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2476 _ => None,
2477 })
2478 .expect("expected tool use in agent message");
2479
2480 let tool_result = agent_message
2481 .tool_results
2482 .get(&tool_use.id)
2483 .expect("expected tool result");
2484
2485 let result_text = match &tool_result.content {
2486 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2487 _ => panic!("expected text content in tool result"),
2488 };
2489
2490 assert!(
2491 result_text.contains("The user stopped this command"),
2492 "expected tool result to indicate user stopped, got: {result_text}"
2493 );
2494 });
2495}
2496
2497#[gpui::test]
2498async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2499 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2500 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2501 always_allow_tools(cx);
2502 let fake_model = model.as_fake();
2503
2504 let environment = Rc::new(cx.update(|cx| {
2505 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2506 }));
2507 let handle = environment.terminal_handle.clone().unwrap();
2508
2509 let mut events = thread
2510 .update(cx, |thread, cx| {
2511 thread.add_tool(crate::TerminalTool::new(
2512 thread.project().clone(),
2513 environment,
2514 ));
2515 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2516 })
2517 .unwrap();
2518
2519 cx.run_until_parked();
2520
2521 // Simulate the model calling the terminal tool with a short timeout
2522 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2523 LanguageModelToolUse {
2524 id: "terminal_tool_1".into(),
2525 name: TerminalTool::NAME.into(),
2526 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2527 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2528 is_input_complete: true,
2529 thought_signature: None,
2530 },
2531 ));
2532 fake_model.end_last_completion_stream();
2533
2534 // Wait for the terminal tool to start running
2535 wait_for_terminal_tool_started(&mut events, cx).await;
2536
2537 // Advance clock past the timeout
2538 cx.executor().advance_clock(Duration::from_millis(200));
2539 cx.run_until_parked();
2540
2541 // The thread continues after tool completion - simulate the model ending its turn
2542 fake_model
2543 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2544 fake_model.end_last_completion_stream();
2545
2546 // Collect remaining events
2547 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2548
2549 // Verify the terminal was killed due to timeout
2550 assert!(
2551 handle.was_killed(),
2552 "expected terminal handle to be killed on timeout"
2553 );
2554
2555 // Verify we got an EndTurn (the tool completed, just with timeout)
2556 assert_eq!(
2557 stop_events(remaining_events),
2558 vec![acp::StopReason::EndTurn],
2559 );
2560
2561 // Verify the tool result indicates timeout, not user stopped
2562 thread.update(cx, |thread, _cx| {
2563 let message = thread.last_message().unwrap();
2564 let agent_message = message.as_agent_message().unwrap();
2565
2566 let tool_use = agent_message
2567 .content
2568 .iter()
2569 .find_map(|content| match content {
2570 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2571 _ => None,
2572 })
2573 .expect("expected tool use in agent message");
2574
2575 let tool_result = agent_message
2576 .tool_results
2577 .get(&tool_use.id)
2578 .expect("expected tool result");
2579
2580 let result_text = match &tool_result.content {
2581 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2582 _ => panic!("expected text content in tool result"),
2583 };
2584
2585 assert!(
2586 result_text.contains("timed out"),
2587 "expected tool result to indicate timeout, got: {result_text}"
2588 );
2589 assert!(
2590 !result_text.contains("The user stopped"),
2591 "tool result should not mention user stopped when it timed out, got: {result_text}"
2592 );
2593 });
2594}
2595
2596#[gpui::test]
2597async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2598 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2599 let fake_model = model.as_fake();
2600
2601 let events_1 = thread
2602 .update(cx, |thread, cx| {
2603 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2604 })
2605 .unwrap();
2606 cx.run_until_parked();
2607 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2608 cx.run_until_parked();
2609
2610 let events_2 = thread
2611 .update(cx, |thread, cx| {
2612 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2613 })
2614 .unwrap();
2615 cx.run_until_parked();
2616 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2617 fake_model
2618 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2619 fake_model.end_last_completion_stream();
2620
2621 let events_1 = events_1.collect::<Vec<_>>().await;
2622 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2623 let events_2 = events_2.collect::<Vec<_>>().await;
2624 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2625}
2626
2627#[gpui::test]
2628async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2629 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2630 let fake_model = model.as_fake();
2631
2632 let events_1 = thread
2633 .update(cx, |thread, cx| {
2634 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2635 })
2636 .unwrap();
2637 cx.run_until_parked();
2638 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2639 fake_model
2640 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2641 fake_model.end_last_completion_stream();
2642 let events_1 = events_1.collect::<Vec<_>>().await;
2643
2644 let events_2 = thread
2645 .update(cx, |thread, cx| {
2646 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2647 })
2648 .unwrap();
2649 cx.run_until_parked();
2650 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2651 fake_model
2652 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2653 fake_model.end_last_completion_stream();
2654 let events_2 = events_2.collect::<Vec<_>>().await;
2655
2656 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2657 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2658}
2659
2660#[gpui::test]
2661async fn test_refusal(cx: &mut TestAppContext) {
2662 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2663 let fake_model = model.as_fake();
2664
2665 let events = thread
2666 .update(cx, |thread, cx| {
2667 thread.send(UserMessageId::new(), ["Hello"], cx)
2668 })
2669 .unwrap();
2670 cx.run_until_parked();
2671 thread.read_with(cx, |thread, _| {
2672 assert_eq!(
2673 thread.to_markdown(),
2674 indoc! {"
2675 ## User
2676
2677 Hello
2678 "}
2679 );
2680 });
2681
2682 fake_model.send_last_completion_stream_text_chunk("Hey!");
2683 cx.run_until_parked();
2684 thread.read_with(cx, |thread, _| {
2685 assert_eq!(
2686 thread.to_markdown(),
2687 indoc! {"
2688 ## User
2689
2690 Hello
2691
2692 ## Assistant
2693
2694 Hey!
2695 "}
2696 );
2697 });
2698
2699 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2700 fake_model
2701 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2702 let events = events.collect::<Vec<_>>().await;
2703 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2704 thread.read_with(cx, |thread, _| {
2705 assert_eq!(thread.to_markdown(), "");
2706 });
2707}
2708
2709#[gpui::test]
2710async fn test_truncate_first_message(cx: &mut TestAppContext) {
2711 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2712 let fake_model = model.as_fake();
2713
2714 let message_id = UserMessageId::new();
2715 thread
2716 .update(cx, |thread, cx| {
2717 thread.send(message_id.clone(), ["Hello"], cx)
2718 })
2719 .unwrap();
2720 cx.run_until_parked();
2721 thread.read_with(cx, |thread, _| {
2722 assert_eq!(
2723 thread.to_markdown(),
2724 indoc! {"
2725 ## User
2726
2727 Hello
2728 "}
2729 );
2730 assert_eq!(thread.latest_token_usage(), None);
2731 });
2732
2733 fake_model.send_last_completion_stream_text_chunk("Hey!");
2734 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2735 language_model::TokenUsage {
2736 input_tokens: 32_000,
2737 output_tokens: 16_000,
2738 cache_creation_input_tokens: 0,
2739 cache_read_input_tokens: 0,
2740 },
2741 ));
2742 cx.run_until_parked();
2743 thread.read_with(cx, |thread, _| {
2744 assert_eq!(
2745 thread.to_markdown(),
2746 indoc! {"
2747 ## User
2748
2749 Hello
2750
2751 ## Assistant
2752
2753 Hey!
2754 "}
2755 );
2756 assert_eq!(
2757 thread.latest_token_usage(),
2758 Some(acp_thread::TokenUsage {
2759 used_tokens: 32_000 + 16_000,
2760 max_tokens: 1_000_000,
2761 max_output_tokens: None,
2762 input_tokens: 32_000,
2763 output_tokens: 16_000,
2764 })
2765 );
2766 });
2767
2768 thread
2769 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2770 .unwrap();
2771 cx.run_until_parked();
2772 thread.read_with(cx, |thread, _| {
2773 assert_eq!(thread.to_markdown(), "");
2774 assert_eq!(thread.latest_token_usage(), None);
2775 });
2776
2777 // Ensure we can still send a new message after truncation.
2778 thread
2779 .update(cx, |thread, cx| {
2780 thread.send(UserMessageId::new(), ["Hi"], cx)
2781 })
2782 .unwrap();
2783 thread.update(cx, |thread, _cx| {
2784 assert_eq!(
2785 thread.to_markdown(),
2786 indoc! {"
2787 ## User
2788
2789 Hi
2790 "}
2791 );
2792 });
2793 cx.run_until_parked();
2794 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2795 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2796 language_model::TokenUsage {
2797 input_tokens: 40_000,
2798 output_tokens: 20_000,
2799 cache_creation_input_tokens: 0,
2800 cache_read_input_tokens: 0,
2801 },
2802 ));
2803 cx.run_until_parked();
2804 thread.read_with(cx, |thread, _| {
2805 assert_eq!(
2806 thread.to_markdown(),
2807 indoc! {"
2808 ## User
2809
2810 Hi
2811
2812 ## Assistant
2813
2814 Ahoy!
2815 "}
2816 );
2817
2818 assert_eq!(
2819 thread.latest_token_usage(),
2820 Some(acp_thread::TokenUsage {
2821 used_tokens: 40_000 + 20_000,
2822 max_tokens: 1_000_000,
2823 max_output_tokens: None,
2824 input_tokens: 40_000,
2825 output_tokens: 20_000,
2826 })
2827 );
2828 });
2829}
2830
2831#[gpui::test]
2832async fn test_truncate_second_message(cx: &mut TestAppContext) {
2833 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2834 let fake_model = model.as_fake();
2835
2836 thread
2837 .update(cx, |thread, cx| {
2838 thread.send(UserMessageId::new(), ["Message 1"], cx)
2839 })
2840 .unwrap();
2841 cx.run_until_parked();
2842 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2843 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2844 language_model::TokenUsage {
2845 input_tokens: 32_000,
2846 output_tokens: 16_000,
2847 cache_creation_input_tokens: 0,
2848 cache_read_input_tokens: 0,
2849 },
2850 ));
2851 fake_model.end_last_completion_stream();
2852 cx.run_until_parked();
2853
2854 let assert_first_message_state = |cx: &mut TestAppContext| {
2855 thread.clone().read_with(cx, |thread, _| {
2856 assert_eq!(
2857 thread.to_markdown(),
2858 indoc! {"
2859 ## User
2860
2861 Message 1
2862
2863 ## Assistant
2864
2865 Message 1 response
2866 "}
2867 );
2868
2869 assert_eq!(
2870 thread.latest_token_usage(),
2871 Some(acp_thread::TokenUsage {
2872 used_tokens: 32_000 + 16_000,
2873 max_tokens: 1_000_000,
2874 max_output_tokens: None,
2875 input_tokens: 32_000,
2876 output_tokens: 16_000,
2877 })
2878 );
2879 });
2880 };
2881
2882 assert_first_message_state(cx);
2883
2884 let second_message_id = UserMessageId::new();
2885 thread
2886 .update(cx, |thread, cx| {
2887 thread.send(second_message_id.clone(), ["Message 2"], cx)
2888 })
2889 .unwrap();
2890 cx.run_until_parked();
2891
2892 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2893 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2894 language_model::TokenUsage {
2895 input_tokens: 40_000,
2896 output_tokens: 20_000,
2897 cache_creation_input_tokens: 0,
2898 cache_read_input_tokens: 0,
2899 },
2900 ));
2901 fake_model.end_last_completion_stream();
2902 cx.run_until_parked();
2903
2904 thread.read_with(cx, |thread, _| {
2905 assert_eq!(
2906 thread.to_markdown(),
2907 indoc! {"
2908 ## User
2909
2910 Message 1
2911
2912 ## Assistant
2913
2914 Message 1 response
2915
2916 ## User
2917
2918 Message 2
2919
2920 ## Assistant
2921
2922 Message 2 response
2923 "}
2924 );
2925
2926 assert_eq!(
2927 thread.latest_token_usage(),
2928 Some(acp_thread::TokenUsage {
2929 used_tokens: 40_000 + 20_000,
2930 max_tokens: 1_000_000,
2931 max_output_tokens: None,
2932 input_tokens: 40_000,
2933 output_tokens: 20_000,
2934 })
2935 );
2936 });
2937
2938 thread
2939 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2940 .unwrap();
2941 cx.run_until_parked();
2942
2943 assert_first_message_state(cx);
2944}
2945
2946#[gpui::test]
2947async fn test_title_generation(cx: &mut TestAppContext) {
2948 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2949 let fake_model = model.as_fake();
2950
2951 let summary_model = Arc::new(FakeLanguageModel::default());
2952 thread.update(cx, |thread, cx| {
2953 thread.set_summarization_model(Some(summary_model.clone()), cx)
2954 });
2955
2956 let send = thread
2957 .update(cx, |thread, cx| {
2958 thread.send(UserMessageId::new(), ["Hello"], cx)
2959 })
2960 .unwrap();
2961 cx.run_until_parked();
2962
2963 fake_model.send_last_completion_stream_text_chunk("Hey!");
2964 fake_model.end_last_completion_stream();
2965 cx.run_until_parked();
2966 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2967
2968 // Ensure the summary model has been invoked to generate a title.
2969 summary_model.send_last_completion_stream_text_chunk("Hello ");
2970 summary_model.send_last_completion_stream_text_chunk("world\nG");
2971 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2972 summary_model.end_last_completion_stream();
2973 send.collect::<Vec<_>>().await;
2974 cx.run_until_parked();
2975 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2976
2977 // Send another message, ensuring no title is generated this time.
2978 let send = thread
2979 .update(cx, |thread, cx| {
2980 thread.send(UserMessageId::new(), ["Hello again"], cx)
2981 })
2982 .unwrap();
2983 cx.run_until_parked();
2984 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2985 fake_model.end_last_completion_stream();
2986 cx.run_until_parked();
2987 assert_eq!(summary_model.pending_completions(), Vec::new());
2988 send.collect::<Vec<_>>().await;
2989 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2990}
2991
2992#[gpui::test]
2993async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2994 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2995 let fake_model = model.as_fake();
2996
2997 let _events = thread
2998 .update(cx, |thread, cx| {
2999 thread.add_tool(ToolRequiringPermission);
3000 thread.add_tool(EchoTool);
3001 thread.send(UserMessageId::new(), ["Hey!"], cx)
3002 })
3003 .unwrap();
3004 cx.run_until_parked();
3005
3006 let permission_tool_use = LanguageModelToolUse {
3007 id: "tool_id_1".into(),
3008 name: ToolRequiringPermission::NAME.into(),
3009 raw_input: "{}".into(),
3010 input: json!({}),
3011 is_input_complete: true,
3012 thought_signature: None,
3013 };
3014 let echo_tool_use = LanguageModelToolUse {
3015 id: "tool_id_2".into(),
3016 name: EchoTool::NAME.into(),
3017 raw_input: json!({"text": "test"}).to_string(),
3018 input: json!({"text": "test"}),
3019 is_input_complete: true,
3020 thought_signature: None,
3021 };
3022 fake_model.send_last_completion_stream_text_chunk("Hi!");
3023 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3024 permission_tool_use,
3025 ));
3026 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3027 echo_tool_use.clone(),
3028 ));
3029 fake_model.end_last_completion_stream();
3030 cx.run_until_parked();
3031
3032 // Ensure pending tools are skipped when building a request.
3033 let request = thread
3034 .read_with(cx, |thread, cx| {
3035 thread.build_completion_request(CompletionIntent::EditFile, cx)
3036 })
3037 .unwrap();
3038 assert_eq!(
3039 request.messages[1..],
3040 vec![
3041 LanguageModelRequestMessage {
3042 role: Role::User,
3043 content: vec!["Hey!".into()],
3044 cache: true,
3045 reasoning_details: None,
3046 },
3047 LanguageModelRequestMessage {
3048 role: Role::Assistant,
3049 content: vec![
3050 MessageContent::Text("Hi!".into()),
3051 MessageContent::ToolUse(echo_tool_use.clone())
3052 ],
3053 cache: false,
3054 reasoning_details: None,
3055 },
3056 LanguageModelRequestMessage {
3057 role: Role::User,
3058 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3059 tool_use_id: echo_tool_use.id.clone(),
3060 tool_name: echo_tool_use.name,
3061 is_error: false,
3062 content: "test".into(),
3063 output: Some("test".into())
3064 })],
3065 cache: false,
3066 reasoning_details: None,
3067 },
3068 ],
3069 );
3070}
3071
3072#[gpui::test]
3073async fn test_agent_connection(cx: &mut TestAppContext) {
3074 cx.update(settings::init);
3075 let templates = Templates::new();
3076
3077 // Initialize language model system with test provider
3078 cx.update(|cx| {
3079 gpui_tokio::init(cx);
3080
3081 let http_client = FakeHttpClient::with_404_response();
3082 let clock = Arc::new(clock::FakeSystemClock::new());
3083 let client = Client::new(clock, http_client, cx);
3084 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3085 language_model::init(client.clone(), cx);
3086 language_models::init(user_store, client.clone(), cx);
3087 LanguageModelRegistry::test(cx);
3088 });
3089 cx.executor().forbid_parking();
3090
3091 // Create a project for new_thread
3092 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3093 fake_fs.insert_tree(path!("/test"), json!({})).await;
3094 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3095 let cwd = Path::new("/test");
3096 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3097
3098 // Create agent and connection
3099 let agent = NativeAgent::new(
3100 project.clone(),
3101 thread_store,
3102 templates.clone(),
3103 None,
3104 fake_fs.clone(),
3105 &mut cx.to_async(),
3106 )
3107 .await
3108 .unwrap();
3109 let connection = NativeAgentConnection(agent.clone());
3110
3111 // Create a thread using new_thread
3112 let connection_rc = Rc::new(connection.clone());
3113 let acp_thread = cx
3114 .update(|cx| connection_rc.new_session(project, cwd, cx))
3115 .await
3116 .expect("new_thread should succeed");
3117
3118 // Get the session_id from the AcpThread
3119 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3120
3121 // Test model_selector returns Some
3122 let selector_opt = connection.model_selector(&session_id);
3123 assert!(
3124 selector_opt.is_some(),
3125 "agent should always support ModelSelector"
3126 );
3127 let selector = selector_opt.unwrap();
3128
3129 // Test list_models
3130 let listed_models = cx
3131 .update(|cx| selector.list_models(cx))
3132 .await
3133 .expect("list_models should succeed");
3134 let AgentModelList::Grouped(listed_models) = listed_models else {
3135 panic!("Unexpected model list type");
3136 };
3137 assert!(!listed_models.is_empty(), "should have at least one model");
3138 assert_eq!(
3139 listed_models[&AgentModelGroupName("Fake".into())][0]
3140 .id
3141 .0
3142 .as_ref(),
3143 "fake/fake"
3144 );
3145
3146 // Test selected_model returns the default
3147 let model = cx
3148 .update(|cx| selector.selected_model(cx))
3149 .await
3150 .expect("selected_model should succeed");
3151 let model = cx
3152 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3153 .unwrap();
3154 let model = model.as_fake();
3155 assert_eq!(model.id().0, "fake", "should return default model");
3156
3157 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3158 cx.run_until_parked();
3159 model.send_last_completion_stream_text_chunk("def");
3160 cx.run_until_parked();
3161 acp_thread.read_with(cx, |thread, cx| {
3162 assert_eq!(
3163 thread.to_markdown(cx),
3164 indoc! {"
3165 ## User
3166
3167 abc
3168
3169 ## Assistant
3170
3171 def
3172
3173 "}
3174 )
3175 });
3176
3177 // Test cancel
3178 cx.update(|cx| connection.cancel(&session_id, cx));
3179 request.await.expect("prompt should fail gracefully");
3180
3181 // Explicitly close the session and drop the ACP thread.
3182 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3183 .await
3184 .unwrap();
3185 drop(acp_thread);
3186 let result = cx
3187 .update(|cx| {
3188 connection.prompt(
3189 Some(acp_thread::UserMessageId::new()),
3190 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3191 cx,
3192 )
3193 })
3194 .await;
3195 assert_eq!(
3196 result.as_ref().unwrap_err().to_string(),
3197 "Session not found",
3198 "unexpected result: {:?}",
3199 result
3200 );
3201}
3202
3203#[gpui::test]
3204async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3205 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3206 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
3207 let fake_model = model.as_fake();
3208
3209 let mut events = thread
3210 .update(cx, |thread, cx| {
3211 thread.send(UserMessageId::new(), ["Echo something"], cx)
3212 })
3213 .unwrap();
3214 cx.run_until_parked();
3215
3216 // Simulate streaming partial input.
3217 let input = json!({});
3218 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3219 LanguageModelToolUse {
3220 id: "1".into(),
3221 name: EchoTool::NAME.into(),
3222 raw_input: input.to_string(),
3223 input,
3224 is_input_complete: false,
3225 thought_signature: None,
3226 },
3227 ));
3228
3229 // Input streaming completed
3230 let input = json!({ "text": "Hello!" });
3231 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3232 LanguageModelToolUse {
3233 id: "1".into(),
3234 name: "echo".into(),
3235 raw_input: input.to_string(),
3236 input,
3237 is_input_complete: true,
3238 thought_signature: None,
3239 },
3240 ));
3241 fake_model.end_last_completion_stream();
3242 cx.run_until_parked();
3243
3244 let tool_call = expect_tool_call(&mut events).await;
3245 assert_eq!(
3246 tool_call,
3247 acp::ToolCall::new("1", "Echo")
3248 .raw_input(json!({}))
3249 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3250 );
3251 let update = expect_tool_call_update_fields(&mut events).await;
3252 assert_eq!(
3253 update,
3254 acp::ToolCallUpdate::new(
3255 "1",
3256 acp::ToolCallUpdateFields::new()
3257 .title("Echo")
3258 .kind(acp::ToolKind::Other)
3259 .raw_input(json!({ "text": "Hello!"}))
3260 )
3261 );
3262 let update = expect_tool_call_update_fields(&mut events).await;
3263 assert_eq!(
3264 update,
3265 acp::ToolCallUpdate::new(
3266 "1",
3267 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3268 )
3269 );
3270 let update = expect_tool_call_update_fields(&mut events).await;
3271 assert_eq!(
3272 update,
3273 acp::ToolCallUpdate::new(
3274 "1",
3275 acp::ToolCallUpdateFields::new()
3276 .status(acp::ToolCallStatus::Completed)
3277 .raw_output("Hello!")
3278 )
3279 );
3280}
3281
3282#[gpui::test]
3283async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3284 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3285 let fake_model = model.as_fake();
3286
3287 let mut events = thread
3288 .update(cx, |thread, cx| {
3289 thread.send(UserMessageId::new(), ["Hello!"], cx)
3290 })
3291 .unwrap();
3292 cx.run_until_parked();
3293
3294 fake_model.send_last_completion_stream_text_chunk("Hey!");
3295 fake_model.end_last_completion_stream();
3296
3297 let mut retry_events = Vec::new();
3298 while let Some(Ok(event)) = events.next().await {
3299 match event {
3300 ThreadEvent::Retry(retry_status) => {
3301 retry_events.push(retry_status);
3302 }
3303 ThreadEvent::Stop(..) => break,
3304 _ => {}
3305 }
3306 }
3307
3308 assert_eq!(retry_events.len(), 0);
3309 thread.read_with(cx, |thread, _cx| {
3310 assert_eq!(
3311 thread.to_markdown(),
3312 indoc! {"
3313 ## User
3314
3315 Hello!
3316
3317 ## Assistant
3318
3319 Hey!
3320 "}
3321 )
3322 });
3323}
3324
3325#[gpui::test]
3326async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3327 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3328 let fake_model = model.as_fake();
3329
3330 let mut events = thread
3331 .update(cx, |thread, cx| {
3332 thread.send(UserMessageId::new(), ["Hello!"], cx)
3333 })
3334 .unwrap();
3335 cx.run_until_parked();
3336
3337 fake_model.send_last_completion_stream_text_chunk("Hey,");
3338 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3339 provider: LanguageModelProviderName::new("Anthropic"),
3340 retry_after: Some(Duration::from_secs(3)),
3341 });
3342 fake_model.end_last_completion_stream();
3343
3344 cx.executor().advance_clock(Duration::from_secs(3));
3345 cx.run_until_parked();
3346
3347 fake_model.send_last_completion_stream_text_chunk("there!");
3348 fake_model.end_last_completion_stream();
3349 cx.run_until_parked();
3350
3351 let mut retry_events = Vec::new();
3352 while let Some(Ok(event)) = events.next().await {
3353 match event {
3354 ThreadEvent::Retry(retry_status) => {
3355 retry_events.push(retry_status);
3356 }
3357 ThreadEvent::Stop(..) => break,
3358 _ => {}
3359 }
3360 }
3361
3362 assert_eq!(retry_events.len(), 1);
3363 assert!(matches!(
3364 retry_events[0],
3365 acp_thread::RetryStatus { attempt: 1, .. }
3366 ));
3367 thread.read_with(cx, |thread, _cx| {
3368 assert_eq!(
3369 thread.to_markdown(),
3370 indoc! {"
3371 ## User
3372
3373 Hello!
3374
3375 ## Assistant
3376
3377 Hey,
3378
3379 [resume]
3380
3381 ## Assistant
3382
3383 there!
3384 "}
3385 )
3386 });
3387}
3388
3389#[gpui::test]
3390async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3391 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3392 let fake_model = model.as_fake();
3393
3394 let events = thread
3395 .update(cx, |thread, cx| {
3396 thread.add_tool(EchoTool);
3397 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3398 })
3399 .unwrap();
3400 cx.run_until_parked();
3401
3402 let tool_use_1 = LanguageModelToolUse {
3403 id: "tool_1".into(),
3404 name: EchoTool::NAME.into(),
3405 raw_input: json!({"text": "test"}).to_string(),
3406 input: json!({"text": "test"}),
3407 is_input_complete: true,
3408 thought_signature: None,
3409 };
3410 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3411 tool_use_1.clone(),
3412 ));
3413 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3414 provider: LanguageModelProviderName::new("Anthropic"),
3415 retry_after: Some(Duration::from_secs(3)),
3416 });
3417 fake_model.end_last_completion_stream();
3418
3419 cx.executor().advance_clock(Duration::from_secs(3));
3420 let completion = fake_model.pending_completions().pop().unwrap();
3421 assert_eq!(
3422 completion.messages[1..],
3423 vec![
3424 LanguageModelRequestMessage {
3425 role: Role::User,
3426 content: vec!["Call the echo tool!".into()],
3427 cache: false,
3428 reasoning_details: None,
3429 },
3430 LanguageModelRequestMessage {
3431 role: Role::Assistant,
3432 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3433 cache: false,
3434 reasoning_details: None,
3435 },
3436 LanguageModelRequestMessage {
3437 role: Role::User,
3438 content: vec![language_model::MessageContent::ToolResult(
3439 LanguageModelToolResult {
3440 tool_use_id: tool_use_1.id.clone(),
3441 tool_name: tool_use_1.name.clone(),
3442 is_error: false,
3443 content: "test".into(),
3444 output: Some("test".into())
3445 }
3446 )],
3447 cache: true,
3448 reasoning_details: None,
3449 },
3450 ]
3451 );
3452
3453 fake_model.send_last_completion_stream_text_chunk("Done");
3454 fake_model.end_last_completion_stream();
3455 cx.run_until_parked();
3456 events.collect::<Vec<_>>().await;
3457 thread.read_with(cx, |thread, _cx| {
3458 assert_eq!(
3459 thread.last_message(),
3460 Some(Message::Agent(AgentMessage {
3461 content: vec![AgentMessageContent::Text("Done".into())],
3462 tool_results: IndexMap::default(),
3463 reasoning_details: None,
3464 }))
3465 );
3466 })
3467}
3468
3469#[gpui::test]
3470async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3471 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3472 let fake_model = model.as_fake();
3473
3474 let mut events = thread
3475 .update(cx, |thread, cx| {
3476 thread.send(UserMessageId::new(), ["Hello!"], cx)
3477 })
3478 .unwrap();
3479 cx.run_until_parked();
3480
3481 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3482 fake_model.send_last_completion_stream_error(
3483 LanguageModelCompletionError::ServerOverloaded {
3484 provider: LanguageModelProviderName::new("Anthropic"),
3485 retry_after: Some(Duration::from_secs(3)),
3486 },
3487 );
3488 fake_model.end_last_completion_stream();
3489 cx.executor().advance_clock(Duration::from_secs(3));
3490 cx.run_until_parked();
3491 }
3492
3493 let mut errors = Vec::new();
3494 let mut retry_events = Vec::new();
3495 while let Some(event) = events.next().await {
3496 match event {
3497 Ok(ThreadEvent::Retry(retry_status)) => {
3498 retry_events.push(retry_status);
3499 }
3500 Ok(ThreadEvent::Stop(..)) => break,
3501 Err(error) => errors.push(error),
3502 _ => {}
3503 }
3504 }
3505
3506 assert_eq!(
3507 retry_events.len(),
3508 crate::thread::MAX_RETRY_ATTEMPTS as usize
3509 );
3510 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3511 assert_eq!(retry_events[i].attempt, i + 1);
3512 }
3513 assert_eq!(errors.len(), 1);
3514 let error = errors[0]
3515 .downcast_ref::<LanguageModelCompletionError>()
3516 .unwrap();
3517 assert!(matches!(
3518 error,
3519 LanguageModelCompletionError::ServerOverloaded { .. }
3520 ));
3521}
3522
3523/// Filters out the stop events for asserting against in tests
3524fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3525 result_events
3526 .into_iter()
3527 .filter_map(|event| match event.unwrap() {
3528 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3529 _ => None,
3530 })
3531 .collect()
3532}
3533
/// Fixture returned by `setup`, bundling the thread under test with the
/// collaborators the tests need to drive it.
struct ThreadTest {
    /// The language model handed to the thread when it was created.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// The context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake file system backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
3541
/// Selects which language model `setup` wires into the thread under test.
enum TestModel {
    /// A model looked up in the global registry by the id "claude-sonnet-4-latest".
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose completions are driven by the test.
    Fake,
}
3546
3547impl TestModel {
3548 fn id(&self) -> LanguageModelId {
3549 match self {
3550 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3551 TestModel::Fake => unreachable!(),
3552 }
3553 }
3554}
3555
3556async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3557 cx.executor().allow_parking();
3558
3559 let fs = FakeFs::new(cx.background_executor.clone());
3560 fs.create_dir(paths::settings_file().parent().unwrap())
3561 .await
3562 .unwrap();
3563 fs.insert_file(
3564 paths::settings_file(),
3565 json!({
3566 "agent": {
3567 "default_profile": "test-profile",
3568 "profiles": {
3569 "test-profile": {
3570 "name": "Test Profile",
3571 "tools": {
3572 EchoTool::NAME: true,
3573 DelayTool::NAME: true,
3574 WordListTool::NAME: true,
3575 ToolRequiringPermission::NAME: true,
3576 InfiniteTool::NAME: true,
3577 CancellationAwareTool::NAME: true,
3578 (TerminalTool::NAME): true,
3579 }
3580 }
3581 }
3582 }
3583 })
3584 .to_string()
3585 .into_bytes(),
3586 )
3587 .await;
3588
3589 cx.update(|cx| {
3590 settings::init(cx);
3591
3592 match model {
3593 TestModel::Fake => {}
3594 TestModel::Sonnet4 => {
3595 gpui_tokio::init(cx);
3596 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3597 cx.set_http_client(Arc::new(http_client));
3598 let client = Client::production(cx);
3599 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3600 language_model::init(client.clone(), cx);
3601 language_models::init(user_store, client.clone(), cx);
3602 }
3603 };
3604
3605 watch_settings(fs.clone(), cx);
3606 });
3607
3608 let templates = Templates::new();
3609
3610 fs.insert_tree(path!("/test"), json!({})).await;
3611 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3612
3613 let model = cx
3614 .update(|cx| {
3615 if let TestModel::Fake = model {
3616 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3617 } else {
3618 let model_id = model.id();
3619 let models = LanguageModelRegistry::read_global(cx);
3620 let model = models
3621 .available_models(cx)
3622 .find(|model| model.id() == model_id)
3623 .unwrap();
3624
3625 let provider = models.provider(&model.provider_id()).unwrap();
3626 let authenticated = provider.authenticate(cx);
3627
3628 cx.spawn(async move |_cx| {
3629 authenticated.await.unwrap();
3630 model
3631 })
3632 }
3633 })
3634 .await;
3635
3636 let project_context = cx.new(|_cx| ProjectContext::default());
3637 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3638 let context_server_registry =
3639 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3640 let thread = cx.new(|cx| {
3641 Thread::new(
3642 project,
3643 project_context.clone(),
3644 context_server_registry,
3645 templates,
3646 Some(model.clone()),
3647 cx,
3648 )
3649 });
3650 ThreadTest {
3651 model,
3652 thread,
3653 project_context,
3654 context_server_store,
3655 fs,
3656 }
3657}
3658
/// Installs `env_logger` when the test binary is loaded, but only if `RUST_LOG`
/// is set, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if !logging_requested {
        return;
    }
    env_logger::init();
}
3666
3667fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3668 let fs = fs.clone();
3669 cx.spawn({
3670 async move |cx| {
3671 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3672 cx.background_executor(),
3673 fs,
3674 paths::settings_file().clone(),
3675 );
3676 let _watcher_task = watcher_task;
3677
3678 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3679 cx.update(|cx| {
3680 SettingsStore::update_global(cx, |settings, cx| {
3681 settings.set_user_settings(&new_settings_content, cx)
3682 })
3683 })
3684 .ok();
3685 }
3686 }
3687 })
3688 .detach();
3689}
3690
3691fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3692 completion
3693 .tools
3694 .iter()
3695 .map(|tool| tool.name.clone())
3696 .collect()
3697}
3698
3699fn setup_context_server(
3700 name: &'static str,
3701 tools: Vec<context_server::types::Tool>,
3702 context_server_store: &Entity<ContextServerStore>,
3703 cx: &mut TestAppContext,
3704) -> mpsc::UnboundedReceiver<(
3705 context_server::types::CallToolParams,
3706 oneshot::Sender<context_server::types::CallToolResponse>,
3707)> {
3708 cx.update(|cx| {
3709 let mut settings = ProjectSettings::get_global(cx).clone();
3710 settings.context_servers.insert(
3711 name.into(),
3712 project::project_settings::ContextServerSettings::Stdio {
3713 enabled: true,
3714 remote: false,
3715 command: ContextServerCommand {
3716 path: "somebinary".into(),
3717 args: Vec::new(),
3718 env: None,
3719 timeout: None,
3720 },
3721 },
3722 );
3723 ProjectSettings::override_global(settings, cx);
3724 });
3725
3726 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3727 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3728 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3729 context_server::types::InitializeResponse {
3730 protocol_version: context_server::types::ProtocolVersion(
3731 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3732 ),
3733 server_info: context_server::types::Implementation {
3734 name: name.into(),
3735 version: "1.0.0".to_string(),
3736 },
3737 capabilities: context_server::types::ServerCapabilities {
3738 tools: Some(context_server::types::ToolsCapabilities {
3739 list_changed: Some(true),
3740 }),
3741 ..Default::default()
3742 },
3743 meta: None,
3744 }
3745 })
3746 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3747 let tools = tools.clone();
3748 async move {
3749 context_server::types::ListToolsResponse {
3750 tools,
3751 next_cursor: None,
3752 meta: None,
3753 }
3754 }
3755 })
3756 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3757 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3758 async move {
3759 let (response_tx, response_rx) = oneshot::channel();
3760 mcp_tool_calls_tx
3761 .unbounded_send((params, response_tx))
3762 .unwrap();
3763 response_rx.await.unwrap()
3764 }
3765 });
3766 context_server_store.update(cx, |store, cx| {
3767 store.start_server(
3768 Arc::new(ContextServer::new(
3769 ContextServerId(name.into()),
3770 Arc::new(fake_transport),
3771 )),
3772 cx,
3773 );
3774 });
3775 cx.run_until_parked();
3776 mcp_tool_calls_rx
3777}
3778
3779#[gpui::test]
3780async fn test_tokens_before_message(cx: &mut TestAppContext) {
3781 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3782 let fake_model = model.as_fake();
3783
3784 // First message
3785 let message_1_id = UserMessageId::new();
3786 thread
3787 .update(cx, |thread, cx| {
3788 thread.send(message_1_id.clone(), ["First message"], cx)
3789 })
3790 .unwrap();
3791 cx.run_until_parked();
3792
3793 // Before any response, tokens_before_message should return None for first message
3794 thread.read_with(cx, |thread, _| {
3795 assert_eq!(
3796 thread.tokens_before_message(&message_1_id),
3797 None,
3798 "First message should have no tokens before it"
3799 );
3800 });
3801
3802 // Complete first message with usage
3803 fake_model.send_last_completion_stream_text_chunk("Response 1");
3804 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3805 language_model::TokenUsage {
3806 input_tokens: 100,
3807 output_tokens: 50,
3808 cache_creation_input_tokens: 0,
3809 cache_read_input_tokens: 0,
3810 },
3811 ));
3812 fake_model.end_last_completion_stream();
3813 cx.run_until_parked();
3814
3815 // First message still has no tokens before it
3816 thread.read_with(cx, |thread, _| {
3817 assert_eq!(
3818 thread.tokens_before_message(&message_1_id),
3819 None,
3820 "First message should still have no tokens before it after response"
3821 );
3822 });
3823
3824 // Second message
3825 let message_2_id = UserMessageId::new();
3826 thread
3827 .update(cx, |thread, cx| {
3828 thread.send(message_2_id.clone(), ["Second message"], cx)
3829 })
3830 .unwrap();
3831 cx.run_until_parked();
3832
3833 // Second message should have first message's input tokens before it
3834 thread.read_with(cx, |thread, _| {
3835 assert_eq!(
3836 thread.tokens_before_message(&message_2_id),
3837 Some(100),
3838 "Second message should have 100 tokens before it (from first request)"
3839 );
3840 });
3841
3842 // Complete second message
3843 fake_model.send_last_completion_stream_text_chunk("Response 2");
3844 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3845 language_model::TokenUsage {
3846 input_tokens: 250, // Total for this request (includes previous context)
3847 output_tokens: 75,
3848 cache_creation_input_tokens: 0,
3849 cache_read_input_tokens: 0,
3850 },
3851 ));
3852 fake_model.end_last_completion_stream();
3853 cx.run_until_parked();
3854
3855 // Third message
3856 let message_3_id = UserMessageId::new();
3857 thread
3858 .update(cx, |thread, cx| {
3859 thread.send(message_3_id.clone(), ["Third message"], cx)
3860 })
3861 .unwrap();
3862 cx.run_until_parked();
3863
3864 // Third message should have second message's input tokens (250) before it
3865 thread.read_with(cx, |thread, _| {
3866 assert_eq!(
3867 thread.tokens_before_message(&message_3_id),
3868 Some(250),
3869 "Third message should have 250 tokens before it (from second request)"
3870 );
3871 // Second message should still have 100
3872 assert_eq!(
3873 thread.tokens_before_message(&message_2_id),
3874 Some(100),
3875 "Second message should still have 100 tokens before it"
3876 );
3877 // First message still has none
3878 assert_eq!(
3879 thread.tokens_before_message(&message_1_id),
3880 None,
3881 "First message should still have no tokens before it"
3882 );
3883 });
3884}
3885
3886#[gpui::test]
3887async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3888 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3889 let fake_model = model.as_fake();
3890
3891 // Set up three messages with responses
3892 let message_1_id = UserMessageId::new();
3893 thread
3894 .update(cx, |thread, cx| {
3895 thread.send(message_1_id.clone(), ["Message 1"], cx)
3896 })
3897 .unwrap();
3898 cx.run_until_parked();
3899 fake_model.send_last_completion_stream_text_chunk("Response 1");
3900 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3901 language_model::TokenUsage {
3902 input_tokens: 100,
3903 output_tokens: 50,
3904 cache_creation_input_tokens: 0,
3905 cache_read_input_tokens: 0,
3906 },
3907 ));
3908 fake_model.end_last_completion_stream();
3909 cx.run_until_parked();
3910
3911 let message_2_id = UserMessageId::new();
3912 thread
3913 .update(cx, |thread, cx| {
3914 thread.send(message_2_id.clone(), ["Message 2"], cx)
3915 })
3916 .unwrap();
3917 cx.run_until_parked();
3918 fake_model.send_last_completion_stream_text_chunk("Response 2");
3919 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3920 language_model::TokenUsage {
3921 input_tokens: 250,
3922 output_tokens: 75,
3923 cache_creation_input_tokens: 0,
3924 cache_read_input_tokens: 0,
3925 },
3926 ));
3927 fake_model.end_last_completion_stream();
3928 cx.run_until_parked();
3929
3930 // Verify initial state
3931 thread.read_with(cx, |thread, _| {
3932 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3933 });
3934
3935 // Truncate at message 2 (removes message 2 and everything after)
3936 thread
3937 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3938 .unwrap();
3939 cx.run_until_parked();
3940
3941 // After truncation, message_2_id no longer exists, so lookup should return None
3942 thread.read_with(cx, |thread, _| {
3943 assert_eq!(
3944 thread.tokens_before_message(&message_2_id),
3945 None,
3946 "After truncation, message 2 no longer exists"
3947 );
3948 // Message 1 still exists but has no tokens before it
3949 assert_eq!(
3950 thread.tokens_before_message(&message_1_id),
3951 None,
3952 "First message still has no tokens before it"
3953 );
3954 });
3955}
3956
3957#[gpui::test]
3958async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3959 init_test(cx);
3960
3961 let fs = FakeFs::new(cx.executor());
3962 fs.insert_tree("/root", json!({})).await;
3963 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3964
3965 // Test 1: Deny rule blocks command
3966 {
3967 let environment = Rc::new(cx.update(|cx| {
3968 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3969 }));
3970
3971 cx.update(|cx| {
3972 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3973 settings.tool_permissions.tools.insert(
3974 TerminalTool::NAME.into(),
3975 agent_settings::ToolRules {
3976 default: Some(settings::ToolPermissionMode::Confirm),
3977 always_allow: vec![],
3978 always_deny: vec![
3979 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3980 ],
3981 always_confirm: vec![],
3982 invalid_patterns: vec![],
3983 },
3984 );
3985 agent_settings::AgentSettings::override_global(settings, cx);
3986 });
3987
3988 #[allow(clippy::arc_with_non_send_sync)]
3989 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3990 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3991
3992 let task = cx.update(|cx| {
3993 tool.run(
3994 crate::TerminalToolInput {
3995 command: "rm -rf /".to_string(),
3996 cd: ".".to_string(),
3997 timeout_ms: None,
3998 },
3999 event_stream,
4000 cx,
4001 )
4002 });
4003
4004 let result = task.await;
4005 assert!(
4006 result.is_err(),
4007 "expected command to be blocked by deny rule"
4008 );
4009 let err_msg = result.unwrap_err().to_lowercase();
4010 assert!(
4011 err_msg.contains("blocked"),
4012 "error should mention the command was blocked"
4013 );
4014 }
4015
4016 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4017 {
4018 let environment = Rc::new(cx.update(|cx| {
4019 FakeThreadEnvironment::default()
4020 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4021 }));
4022
4023 cx.update(|cx| {
4024 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4025 settings.tool_permissions.tools.insert(
4026 TerminalTool::NAME.into(),
4027 agent_settings::ToolRules {
4028 default: Some(settings::ToolPermissionMode::Deny),
4029 always_allow: vec![
4030 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4031 ],
4032 always_deny: vec![],
4033 always_confirm: vec![],
4034 invalid_patterns: vec![],
4035 },
4036 );
4037 agent_settings::AgentSettings::override_global(settings, cx);
4038 });
4039
4040 #[allow(clippy::arc_with_non_send_sync)]
4041 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4042 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4043
4044 let task = cx.update(|cx| {
4045 tool.run(
4046 crate::TerminalToolInput {
4047 command: "echo hello".to_string(),
4048 cd: ".".to_string(),
4049 timeout_ms: None,
4050 },
4051 event_stream,
4052 cx,
4053 )
4054 });
4055
4056 let update = rx.expect_update_fields().await;
4057 assert!(
4058 update.content.iter().any(|blocks| {
4059 blocks
4060 .iter()
4061 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4062 }),
4063 "expected terminal content (allow rule should skip confirmation and override default deny)"
4064 );
4065
4066 let result = task.await;
4067 assert!(
4068 result.is_ok(),
4069 "expected command to succeed without confirmation"
4070 );
4071 }
4072
4073 // Test 3: global default: allow does NOT override always_confirm patterns
4074 {
4075 let environment = Rc::new(cx.update(|cx| {
4076 FakeThreadEnvironment::default()
4077 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4078 }));
4079
4080 cx.update(|cx| {
4081 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4082 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4083 settings.tool_permissions.tools.insert(
4084 TerminalTool::NAME.into(),
4085 agent_settings::ToolRules {
4086 default: Some(settings::ToolPermissionMode::Allow),
4087 always_allow: vec![],
4088 always_deny: vec![],
4089 always_confirm: vec![
4090 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4091 ],
4092 invalid_patterns: vec![],
4093 },
4094 );
4095 agent_settings::AgentSettings::override_global(settings, cx);
4096 });
4097
4098 #[allow(clippy::arc_with_non_send_sync)]
4099 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4100 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4101
4102 let _task = cx.update(|cx| {
4103 tool.run(
4104 crate::TerminalToolInput {
4105 command: "sudo rm file".to_string(),
4106 cd: ".".to_string(),
4107 timeout_ms: None,
4108 },
4109 event_stream,
4110 cx,
4111 )
4112 });
4113
4114 // With global default: allow, confirm patterns are still respected
4115 // The expect_authorization() call will panic if no authorization is requested,
4116 // which validates that the confirm pattern still triggers confirmation
4117 let _auth = rx.expect_authorization().await;
4118
4119 drop(_task);
4120 }
4121
4122 // Test 4: tool-specific default: deny is respected even with global default: allow
4123 {
4124 let environment = Rc::new(cx.update(|cx| {
4125 FakeThreadEnvironment::default()
4126 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4127 }));
4128
4129 cx.update(|cx| {
4130 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4131 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4132 settings.tool_permissions.tools.insert(
4133 TerminalTool::NAME.into(),
4134 agent_settings::ToolRules {
4135 default: Some(settings::ToolPermissionMode::Deny),
4136 always_allow: vec![],
4137 always_deny: vec![],
4138 always_confirm: vec![],
4139 invalid_patterns: vec![],
4140 },
4141 );
4142 agent_settings::AgentSettings::override_global(settings, cx);
4143 });
4144
4145 #[allow(clippy::arc_with_non_send_sync)]
4146 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4147 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4148
4149 let task = cx.update(|cx| {
4150 tool.run(
4151 crate::TerminalToolInput {
4152 command: "echo hello".to_string(),
4153 cd: ".".to_string(),
4154 timeout_ms: None,
4155 },
4156 event_stream,
4157 cx,
4158 )
4159 });
4160
4161 // tool-specific default: deny is respected even with global default: allow
4162 let result = task.await;
4163 assert!(
4164 result.is_err(),
4165 "expected command to be blocked by tool-specific deny default"
4166 );
4167 let err_msg = result.unwrap_err().to_lowercase();
4168 assert!(
4169 err_msg.contains("disabled"),
4170 "error should mention the tool is disabled, got: {err_msg}"
4171 );
4172 }
4173}
4174
4175#[gpui::test]
4176async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4177 init_test(cx);
4178 cx.update(|cx| {
4179 LanguageModelRegistry::test(cx);
4180 });
4181 cx.update(|cx| {
4182 cx.update_flags(true, vec!["subagents".to_string()]);
4183 });
4184
4185 let fs = FakeFs::new(cx.executor());
4186 fs.insert_tree(
4187 "/",
4188 json!({
4189 "a": {
4190 "b.md": "Lorem"
4191 }
4192 }),
4193 )
4194 .await;
4195 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4196 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4197 let agent = NativeAgent::new(
4198 project.clone(),
4199 thread_store.clone(),
4200 Templates::new(),
4201 None,
4202 fs.clone(),
4203 &mut cx.to_async(),
4204 )
4205 .await
4206 .unwrap();
4207 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4208
4209 let acp_thread = cx
4210 .update(|cx| {
4211 connection
4212 .clone()
4213 .new_session(project.clone(), Path::new(""), cx)
4214 })
4215 .await
4216 .unwrap();
4217 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4218 let thread = agent.read_with(cx, |agent, _| {
4219 agent.sessions.get(&session_id).unwrap().thread.clone()
4220 });
4221 let model = Arc::new(FakeLanguageModel::default());
4222
4223 // Ensure empty threads are not saved, even if they get mutated.
4224 thread.update(cx, |thread, cx| {
4225 thread.set_model(model.clone(), cx);
4226 });
4227 cx.run_until_parked();
4228
4229 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4230 cx.run_until_parked();
4231 model.send_last_completion_stream_text_chunk("spawning subagent");
4232 let subagent_tool_input = SpawnAgentToolInput {
4233 label: "label".to_string(),
4234 message: "subagent task prompt".to_string(),
4235 session_id: None,
4236 };
4237 let subagent_tool_use = LanguageModelToolUse {
4238 id: "subagent_1".into(),
4239 name: SpawnAgentTool::NAME.into(),
4240 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4241 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4242 is_input_complete: true,
4243 thought_signature: None,
4244 };
4245 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4246 subagent_tool_use,
4247 ));
4248 model.end_last_completion_stream();
4249
4250 cx.run_until_parked();
4251
4252 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4253 thread
4254 .running_subagent_ids(cx)
4255 .get(0)
4256 .expect("subagent thread should be running")
4257 .clone()
4258 });
4259
4260 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4261 agent
4262 .sessions
4263 .get(&subagent_session_id)
4264 .expect("subagent session should exist")
4265 .acp_thread
4266 .clone()
4267 });
4268
4269 model.send_last_completion_stream_text_chunk("subagent task response");
4270 model.end_last_completion_stream();
4271
4272 cx.run_until_parked();
4273
4274 assert_eq!(
4275 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4276 indoc! {"
4277 ## User
4278
4279 subagent task prompt
4280
4281 ## Assistant
4282
4283 subagent task response
4284
4285 "}
4286 );
4287
4288 model.send_last_completion_stream_text_chunk("Response");
4289 model.end_last_completion_stream();
4290
4291 send.await.unwrap();
4292
4293 assert_eq!(
4294 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4295 indoc! {r#"
4296 ## User
4297
4298 Prompt
4299
4300 ## Assistant
4301
4302 spawning subagent
4303
4304 **Tool Call: label**
4305 Status: Completed
4306
4307 subagent task response
4308
4309
4310 ## Assistant
4311
4312 Response
4313
4314 "#},
4315 );
4316}
4317
4318#[gpui::test]
4319async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4320 init_test(cx);
4321 cx.update(|cx| {
4322 LanguageModelRegistry::test(cx);
4323 });
4324 cx.update(|cx| {
4325 cx.update_flags(true, vec!["subagents".to_string()]);
4326 });
4327
4328 let fs = FakeFs::new(cx.executor());
4329 fs.insert_tree(
4330 "/",
4331 json!({
4332 "a": {
4333 "b.md": "Lorem"
4334 }
4335 }),
4336 )
4337 .await;
4338 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4339 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4340 let agent = NativeAgent::new(
4341 project.clone(),
4342 thread_store.clone(),
4343 Templates::new(),
4344 None,
4345 fs.clone(),
4346 &mut cx.to_async(),
4347 )
4348 .await
4349 .unwrap();
4350 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4351
4352 let acp_thread = cx
4353 .update(|cx| {
4354 connection
4355 .clone()
4356 .new_session(project.clone(), Path::new(""), cx)
4357 })
4358 .await
4359 .unwrap();
4360 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4361 let thread = agent.read_with(cx, |agent, _| {
4362 agent.sessions.get(&session_id).unwrap().thread.clone()
4363 });
4364 let model = Arc::new(FakeLanguageModel::default());
4365
4366 // Ensure empty threads are not saved, even if they get mutated.
4367 thread.update(cx, |thread, cx| {
4368 thread.set_model(model.clone(), cx);
4369 });
4370 cx.run_until_parked();
4371
4372 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4373 cx.run_until_parked();
4374 model.send_last_completion_stream_text_chunk("spawning subagent");
4375 let subagent_tool_input = SpawnAgentToolInput {
4376 label: "label".to_string(),
4377 message: "subagent task prompt".to_string(),
4378 session_id: None,
4379 };
4380 let subagent_tool_use = LanguageModelToolUse {
4381 id: "subagent_1".into(),
4382 name: SpawnAgentTool::NAME.into(),
4383 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4384 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4385 is_input_complete: true,
4386 thought_signature: None,
4387 };
4388 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4389 subagent_tool_use,
4390 ));
4391 model.end_last_completion_stream();
4392
4393 cx.run_until_parked();
4394
4395 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4396 thread
4397 .running_subagent_ids(cx)
4398 .get(0)
4399 .expect("subagent thread should be running")
4400 .clone()
4401 });
4402 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4403 agent
4404 .sessions
4405 .get(&subagent_session_id)
4406 .expect("subagent session should exist")
4407 .acp_thread
4408 .clone()
4409 });
4410
4411 // model.send_last_completion_stream_text_chunk("subagent task response");
4412 // model.end_last_completion_stream();
4413
4414 // cx.run_until_parked();
4415
4416 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4417
4418 cx.run_until_parked();
4419
4420 send.await.unwrap();
4421
4422 acp_thread.read_with(cx, |thread, cx| {
4423 assert_eq!(thread.status(), ThreadStatus::Idle);
4424 assert_eq!(
4425 thread.to_markdown(cx),
4426 indoc! {"
4427 ## User
4428
4429 Prompt
4430
4431 ## Assistant
4432
4433 spawning subagent
4434
4435 **Tool Call: label**
4436 Status: Canceled
4437
4438 "}
4439 );
4440 });
4441 subagent_acp_thread.read_with(cx, |thread, cx| {
4442 assert_eq!(thread.status(), ThreadStatus::Idle);
4443 assert_eq!(
4444 thread.to_markdown(cx),
4445 indoc! {"
4446 ## User
4447
4448 subagent task prompt
4449
4450 "}
4451 );
4452 });
4453}
4454
4455#[gpui::test]
4456async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4457 init_test(cx);
4458 cx.update(|cx| {
4459 LanguageModelRegistry::test(cx);
4460 });
4461 cx.update(|cx| {
4462 cx.update_flags(true, vec!["subagents".to_string()]);
4463 });
4464
4465 let fs = FakeFs::new(cx.executor());
4466 fs.insert_tree(
4467 "/",
4468 json!({
4469 "a": {
4470 "b.md": "Lorem"
4471 }
4472 }),
4473 )
4474 .await;
4475 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4476 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4477 let agent = NativeAgent::new(
4478 project.clone(),
4479 thread_store.clone(),
4480 Templates::new(),
4481 None,
4482 fs.clone(),
4483 &mut cx.to_async(),
4484 )
4485 .await
4486 .unwrap();
4487 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4488
4489 let acp_thread = cx
4490 .update(|cx| {
4491 connection
4492 .clone()
4493 .new_session(project.clone(), Path::new(""), cx)
4494 })
4495 .await
4496 .unwrap();
4497 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4498 let thread = agent.read_with(cx, |agent, _| {
4499 agent.sessions.get(&session_id).unwrap().thread.clone()
4500 });
4501 let model = Arc::new(FakeLanguageModel::default());
4502
4503 thread.update(cx, |thread, cx| {
4504 thread.set_model(model.clone(), cx);
4505 });
4506 cx.run_until_parked();
4507
4508 // === First turn: create subagent ===
4509 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4510 cx.run_until_parked();
4511 model.send_last_completion_stream_text_chunk("spawning subagent");
4512 let subagent_tool_input = SpawnAgentToolInput {
4513 label: "initial task".to_string(),
4514 message: "do the first task".to_string(),
4515 session_id: None,
4516 };
4517 let subagent_tool_use = LanguageModelToolUse {
4518 id: "subagent_1".into(),
4519 name: SpawnAgentTool::NAME.into(),
4520 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4521 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4522 is_input_complete: true,
4523 thought_signature: None,
4524 };
4525 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4526 subagent_tool_use,
4527 ));
4528 model.end_last_completion_stream();
4529
4530 cx.run_until_parked();
4531
4532 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4533 thread
4534 .running_subagent_ids(cx)
4535 .get(0)
4536 .expect("subagent thread should be running")
4537 .clone()
4538 });
4539
4540 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4541 agent
4542 .sessions
4543 .get(&subagent_session_id)
4544 .expect("subagent session should exist")
4545 .acp_thread
4546 .clone()
4547 });
4548
4549 // Subagent responds
4550 model.send_last_completion_stream_text_chunk("first task response");
4551 model.end_last_completion_stream();
4552
4553 cx.run_until_parked();
4554
4555 // Parent model responds to complete first turn
4556 model.send_last_completion_stream_text_chunk("First response");
4557 model.end_last_completion_stream();
4558
4559 send.await.unwrap();
4560
4561 // Verify subagent is no longer running
4562 thread.read_with(cx, |thread, cx| {
4563 assert!(
4564 thread.running_subagent_ids(cx).is_empty(),
4565 "subagent should not be running after completion"
4566 );
4567 });
4568
4569 // === Second turn: resume subagent with session_id ===
4570 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4571 cx.run_until_parked();
4572 model.send_last_completion_stream_text_chunk("resuming subagent");
4573 let resume_tool_input = SpawnAgentToolInput {
4574 label: "follow-up task".to_string(),
4575 message: "do the follow-up task".to_string(),
4576 session_id: Some(subagent_session_id.clone()),
4577 };
4578 let resume_tool_use = LanguageModelToolUse {
4579 id: "subagent_2".into(),
4580 name: SpawnAgentTool::NAME.into(),
4581 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4582 input: serde_json::to_value(&resume_tool_input).unwrap(),
4583 is_input_complete: true,
4584 thought_signature: None,
4585 };
4586 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4587 model.end_last_completion_stream();
4588
4589 cx.run_until_parked();
4590
4591 // Subagent should be running again with the same session
4592 thread.read_with(cx, |thread, cx| {
4593 let running = thread.running_subagent_ids(cx);
4594 assert_eq!(running.len(), 1, "subagent should be running");
4595 assert_eq!(running[0], subagent_session_id, "should be same session");
4596 });
4597
4598 // Subagent responds to follow-up
4599 model.send_last_completion_stream_text_chunk("follow-up task response");
4600 model.end_last_completion_stream();
4601
4602 cx.run_until_parked();
4603
4604 // Parent model responds to complete second turn
4605 model.send_last_completion_stream_text_chunk("Second response");
4606 model.end_last_completion_stream();
4607
4608 send2.await.unwrap();
4609
4610 // Verify subagent is no longer running
4611 thread.read_with(cx, |thread, cx| {
4612 assert!(
4613 thread.running_subagent_ids(cx).is_empty(),
4614 "subagent should not be running after resume completion"
4615 );
4616 });
4617
4618 // Verify the subagent's acp thread has both conversation turns
4619 assert_eq!(
4620 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4621 indoc! {"
4622 ## User
4623
4624 do the first task
4625
4626 ## Assistant
4627
4628 first task response
4629
4630 ## User
4631
4632 do the follow-up task
4633
4634 ## Assistant
4635
4636 follow-up task response
4637
4638 "}
4639 );
4640}
4641
4642#[gpui::test]
4643async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4644 init_test(cx);
4645
4646 cx.update(|cx| {
4647 cx.update_flags(true, vec!["subagents".to_string()]);
4648 });
4649
4650 let fs = FakeFs::new(cx.executor());
4651 fs.insert_tree(path!("/test"), json!({})).await;
4652 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4653 let project_context = cx.new(|_cx| ProjectContext::default());
4654 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4655 let context_server_registry =
4656 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4657 let model = Arc::new(FakeLanguageModel::default());
4658
4659 let environment = Rc::new(cx.update(|cx| {
4660 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4661 }));
4662
4663 let thread = cx.new(|cx| {
4664 let mut thread = Thread::new(
4665 project.clone(),
4666 project_context,
4667 context_server_registry,
4668 Templates::new(),
4669 Some(model),
4670 cx,
4671 );
4672 thread.add_default_tools(environment, cx);
4673 thread
4674 });
4675
4676 thread.read_with(cx, |thread, _| {
4677 assert!(
4678 thread.has_registered_tool(SpawnAgentTool::NAME),
4679 "subagent tool should be present when feature flag is enabled"
4680 );
4681 });
4682}
4683
4684#[gpui::test]
4685async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4686 init_test(cx);
4687
4688 cx.update(|cx| {
4689 cx.update_flags(true, vec!["subagents".to_string()]);
4690 });
4691
4692 let fs = FakeFs::new(cx.executor());
4693 fs.insert_tree(path!("/test"), json!({})).await;
4694 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4695 let project_context = cx.new(|_cx| ProjectContext::default());
4696 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4697 let context_server_registry =
4698 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4699 let model = Arc::new(FakeLanguageModel::default());
4700
4701 let parent_thread = cx.new(|cx| {
4702 Thread::new(
4703 project.clone(),
4704 project_context,
4705 context_server_registry,
4706 Templates::new(),
4707 Some(model.clone()),
4708 cx,
4709 )
4710 });
4711
4712 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4713 subagent_thread.read_with(cx, |subagent_thread, cx| {
4714 assert!(subagent_thread.is_subagent());
4715 assert_eq!(subagent_thread.depth(), 1);
4716 assert_eq!(
4717 subagent_thread.model().map(|model| model.id()),
4718 Some(model.id())
4719 );
4720 assert_eq!(
4721 subagent_thread.parent_thread_id(),
4722 Some(parent_thread.read(cx).id().clone())
4723 );
4724 });
4725}
4726
4727#[gpui::test]
4728async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4729 init_test(cx);
4730
4731 cx.update(|cx| {
4732 cx.update_flags(true, vec!["subagents".to_string()]);
4733 });
4734
4735 let fs = FakeFs::new(cx.executor());
4736 fs.insert_tree(path!("/test"), json!({})).await;
4737 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4738 let project_context = cx.new(|_cx| ProjectContext::default());
4739 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4740 let context_server_registry =
4741 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4742 let model = Arc::new(FakeLanguageModel::default());
4743 let environment = Rc::new(cx.update(|cx| {
4744 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4745 }));
4746
4747 let deep_parent_thread = cx.new(|cx| {
4748 let mut thread = Thread::new(
4749 project.clone(),
4750 project_context,
4751 context_server_registry,
4752 Templates::new(),
4753 Some(model.clone()),
4754 cx,
4755 );
4756 thread.set_subagent_context(SubagentContext {
4757 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4758 depth: MAX_SUBAGENT_DEPTH - 1,
4759 });
4760 thread
4761 });
4762 let deep_subagent_thread = cx.new(|cx| {
4763 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4764 thread.add_default_tools(environment, cx);
4765 thread
4766 });
4767
4768 deep_subagent_thread.read_with(cx, |thread, _| {
4769 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4770 assert!(
4771 !thread.has_registered_tool(SpawnAgentTool::NAME),
4772 "subagent tool should not be present at max depth"
4773 );
4774 });
4775}
4776
4777#[gpui::test]
4778async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4779 init_test(cx);
4780
4781 cx.update(|cx| {
4782 cx.update_flags(true, vec!["subagents".to_string()]);
4783 });
4784
4785 let fs = FakeFs::new(cx.executor());
4786 fs.insert_tree(path!("/test"), json!({})).await;
4787 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4788 let project_context = cx.new(|_cx| ProjectContext::default());
4789 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4790 let context_server_registry =
4791 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4792 let model = Arc::new(FakeLanguageModel::default());
4793
4794 let parent = cx.new(|cx| {
4795 Thread::new(
4796 project.clone(),
4797 project_context.clone(),
4798 context_server_registry.clone(),
4799 Templates::new(),
4800 Some(model.clone()),
4801 cx,
4802 )
4803 });
4804
4805 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4806
4807 parent.update(cx, |thread, _cx| {
4808 thread.register_running_subagent(subagent.downgrade());
4809 });
4810
4811 subagent
4812 .update(cx, |thread, cx| {
4813 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4814 })
4815 .unwrap();
4816 cx.run_until_parked();
4817
4818 subagent.read_with(cx, |thread, _| {
4819 assert!(!thread.is_turn_complete(), "subagent should be running");
4820 });
4821
4822 parent.update(cx, |thread, cx| {
4823 thread.cancel(cx).detach();
4824 });
4825
4826 subagent.read_with(cx, |thread, _| {
4827 assert!(
4828 thread.is_turn_complete(),
4829 "subagent should be cancelled when parent cancels"
4830 );
4831 });
4832}
4833
4834#[gpui::test]
4835async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
4836 init_test(cx);
4837 cx.update(|cx| {
4838 LanguageModelRegistry::test(cx);
4839 });
4840 cx.update(|cx| {
4841 cx.update_flags(true, vec!["subagents".to_string()]);
4842 });
4843
4844 let fs = FakeFs::new(cx.executor());
4845 fs.insert_tree(
4846 "/",
4847 json!({
4848 "a": {
4849 "b.md": "Lorem"
4850 }
4851 }),
4852 )
4853 .await;
4854 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4855 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4856 let agent = NativeAgent::new(
4857 project.clone(),
4858 thread_store.clone(),
4859 Templates::new(),
4860 None,
4861 fs.clone(),
4862 &mut cx.to_async(),
4863 )
4864 .await
4865 .unwrap();
4866 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4867
4868 let acp_thread = cx
4869 .update(|cx| {
4870 connection
4871 .clone()
4872 .new_session(project.clone(), Path::new(""), cx)
4873 })
4874 .await
4875 .unwrap();
4876 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4877 let thread = agent.read_with(cx, |agent, _| {
4878 agent.sessions.get(&session_id).unwrap().thread.clone()
4879 });
4880 let model = Arc::new(FakeLanguageModel::default());
4881
4882 thread.update(cx, |thread, cx| {
4883 thread.set_model(model.clone(), cx);
4884 });
4885 cx.run_until_parked();
4886
4887 // Start the parent turn
4888 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4889 cx.run_until_parked();
4890 model.send_last_completion_stream_text_chunk("spawning subagent");
4891 let subagent_tool_input = SpawnAgentToolInput {
4892 label: "label".to_string(),
4893 message: "subagent task prompt".to_string(),
4894 session_id: None,
4895 };
4896 let subagent_tool_use = LanguageModelToolUse {
4897 id: "subagent_1".into(),
4898 name: SpawnAgentTool::NAME.into(),
4899 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4900 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4901 is_input_complete: true,
4902 thought_signature: None,
4903 };
4904 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4905 subagent_tool_use,
4906 ));
4907 model.end_last_completion_stream();
4908
4909 cx.run_until_parked();
4910
4911 // Verify subagent is running
4912 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4913 thread
4914 .running_subagent_ids(cx)
4915 .get(0)
4916 .expect("subagent thread should be running")
4917 .clone()
4918 });
4919
4920 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
4921 model.send_last_completion_stream_text_chunk("partial work");
4922 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
4923 TokenUsage {
4924 input_tokens: 850_000,
4925 output_tokens: 0,
4926 cache_creation_input_tokens: 0,
4927 cache_read_input_tokens: 0,
4928 },
4929 ));
4930
4931 cx.run_until_parked();
4932
4933 // The subagent should no longer be running
4934 thread.read_with(cx, |thread, cx| {
4935 assert!(
4936 thread.running_subagent_ids(cx).is_empty(),
4937 "subagent should be stopped after context window warning"
4938 );
4939 });
4940
4941 // The parent model should get a new completion request to respond to the tool error
4942 model.send_last_completion_stream_text_chunk("Response after warning");
4943 model.end_last_completion_stream();
4944
4945 send.await.unwrap();
4946
4947 // Verify the parent thread shows the warning error in the tool call
4948 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
4949 assert!(
4950 markdown.contains("nearing the end of its context window"),
4951 "tool output should contain context window warning message, got:\n{markdown}"
4952 );
4953 assert!(
4954 markdown.contains("Status: Failed"),
4955 "tool call should have Failed status, got:\n{markdown}"
4956 );
4957
4958 // Verify the subagent session still exists (can be resumed)
4959 agent.read_with(cx, |agent, _cx| {
4960 assert!(
4961 agent.sessions.contains_key(&subagent_session_id),
4962 "subagent session should still exist for potential resume"
4963 );
4964 });
4965}
4966
4967#[gpui::test]
4968async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
4969 init_test(cx);
4970 cx.update(|cx| {
4971 LanguageModelRegistry::test(cx);
4972 });
4973 cx.update(|cx| {
4974 cx.update_flags(true, vec!["subagents".to_string()]);
4975 });
4976
4977 let fs = FakeFs::new(cx.executor());
4978 fs.insert_tree(
4979 "/",
4980 json!({
4981 "a": {
4982 "b.md": "Lorem"
4983 }
4984 }),
4985 )
4986 .await;
4987 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4988 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4989 let agent = NativeAgent::new(
4990 project.clone(),
4991 thread_store.clone(),
4992 Templates::new(),
4993 None,
4994 fs.clone(),
4995 &mut cx.to_async(),
4996 )
4997 .await
4998 .unwrap();
4999 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5000
5001 let acp_thread = cx
5002 .update(|cx| {
5003 connection
5004 .clone()
5005 .new_session(project.clone(), Path::new(""), cx)
5006 })
5007 .await
5008 .unwrap();
5009 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5010 let thread = agent.read_with(cx, |agent, _| {
5011 agent.sessions.get(&session_id).unwrap().thread.clone()
5012 });
5013 let model = Arc::new(FakeLanguageModel::default());
5014
5015 thread.update(cx, |thread, cx| {
5016 thread.set_model(model.clone(), cx);
5017 });
5018 cx.run_until_parked();
5019
5020 // === First turn: create subagent, trigger context window warning ===
5021 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
5022 cx.run_until_parked();
5023 model.send_last_completion_stream_text_chunk("spawning subagent");
5024 let subagent_tool_input = SpawnAgentToolInput {
5025 label: "initial task".to_string(),
5026 message: "do the first task".to_string(),
5027 session_id: None,
5028 };
5029 let subagent_tool_use = LanguageModelToolUse {
5030 id: "subagent_1".into(),
5031 name: SpawnAgentTool::NAME.into(),
5032 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5033 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5034 is_input_complete: true,
5035 thought_signature: None,
5036 };
5037 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5038 subagent_tool_use,
5039 ));
5040 model.end_last_completion_stream();
5041
5042 cx.run_until_parked();
5043
5044 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5045 thread
5046 .running_subagent_ids(cx)
5047 .get(0)
5048 .expect("subagent thread should be running")
5049 .clone()
5050 });
5051
5052 // Subagent sends a usage update that crosses the warning threshold.
5053 // This triggers Normal→Warning, stopping the subagent.
5054 model.send_last_completion_stream_text_chunk("partial work");
5055 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5056 TokenUsage {
5057 input_tokens: 850_000,
5058 output_tokens: 0,
5059 cache_creation_input_tokens: 0,
5060 cache_read_input_tokens: 0,
5061 },
5062 ));
5063
5064 cx.run_until_parked();
5065
5066 // Verify the first turn was stopped with a context window warning
5067 thread.read_with(cx, |thread, cx| {
5068 assert!(
5069 thread.running_subagent_ids(cx).is_empty(),
5070 "subagent should be stopped after context window warning"
5071 );
5072 });
5073
5074 // Parent model responds to complete first turn
5075 model.send_last_completion_stream_text_chunk("First response");
5076 model.end_last_completion_stream();
5077
5078 send.await.unwrap();
5079
5080 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5081 assert!(
5082 markdown.contains("nearing the end of its context window"),
5083 "first turn should have context window warning, got:\n{markdown}"
5084 );
5085
5086 // === Second turn: resume the same subagent (now at Warning level) ===
5087 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
5088 cx.run_until_parked();
5089 model.send_last_completion_stream_text_chunk("resuming subagent");
5090 let resume_tool_input = SpawnAgentToolInput {
5091 label: "follow-up task".to_string(),
5092 message: "do the follow-up task".to_string(),
5093 session_id: Some(subagent_session_id.clone()),
5094 };
5095 let resume_tool_use = LanguageModelToolUse {
5096 id: "subagent_2".into(),
5097 name: SpawnAgentTool::NAME.into(),
5098 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
5099 input: serde_json::to_value(&resume_tool_input).unwrap(),
5100 is_input_complete: true,
5101 thought_signature: None,
5102 };
5103 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
5104 model.end_last_completion_stream();
5105
5106 cx.run_until_parked();
5107
5108 // Subagent responds with tokens still at warning level (no worse).
5109 // Since ratio_before_prompt was already Warning, this should NOT
5110 // trigger the context window warning again.
5111 model.send_last_completion_stream_text_chunk("follow-up task response");
5112 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5113 TokenUsage {
5114 input_tokens: 870_000,
5115 output_tokens: 0,
5116 cache_creation_input_tokens: 0,
5117 cache_read_input_tokens: 0,
5118 },
5119 ));
5120 model.end_last_completion_stream();
5121
5122 cx.run_until_parked();
5123
5124 // Parent model responds to complete second turn
5125 model.send_last_completion_stream_text_chunk("Second response");
5126 model.end_last_completion_stream();
5127
5128 send2.await.unwrap();
5129
5130 // The resumed subagent should have completed normally since the ratio
5131 // didn't transition (it was Warning before and stayed at Warning)
5132 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5133 assert!(
5134 markdown.contains("follow-up task response"),
5135 "resumed subagent should complete normally when already at warning, got:\n{markdown}"
5136 );
5137 // The second tool call should NOT have a context window warning
5138 let second_tool_pos = markdown
5139 .find("follow-up task")
5140 .expect("should find follow-up tool call");
5141 let after_second_tool = &markdown[second_tool_pos..];
5142 assert!(
5143 !after_second_tool.contains("nearing the end of its context window"),
5144 "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
5145 );
5146}
5147
5148#[gpui::test]
5149async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
5150 init_test(cx);
5151 cx.update(|cx| {
5152 LanguageModelRegistry::test(cx);
5153 });
5154 cx.update(|cx| {
5155 cx.update_flags(true, vec!["subagents".to_string()]);
5156 });
5157
5158 let fs = FakeFs::new(cx.executor());
5159 fs.insert_tree(
5160 "/",
5161 json!({
5162 "a": {
5163 "b.md": "Lorem"
5164 }
5165 }),
5166 )
5167 .await;
5168 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5169 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5170 let agent = NativeAgent::new(
5171 project.clone(),
5172 thread_store.clone(),
5173 Templates::new(),
5174 None,
5175 fs.clone(),
5176 &mut cx.to_async(),
5177 )
5178 .await
5179 .unwrap();
5180 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5181
5182 let acp_thread = cx
5183 .update(|cx| {
5184 connection
5185 .clone()
5186 .new_session(project.clone(), Path::new(""), cx)
5187 })
5188 .await
5189 .unwrap();
5190 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5191 let thread = agent.read_with(cx, |agent, _| {
5192 agent.sessions.get(&session_id).unwrap().thread.clone()
5193 });
5194 let model = Arc::new(FakeLanguageModel::default());
5195
5196 thread.update(cx, |thread, cx| {
5197 thread.set_model(model.clone(), cx);
5198 });
5199 cx.run_until_parked();
5200
5201 // Start the parent turn
5202 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5203 cx.run_until_parked();
5204 model.send_last_completion_stream_text_chunk("spawning subagent");
5205 let subagent_tool_input = SpawnAgentToolInput {
5206 label: "label".to_string(),
5207 message: "subagent task prompt".to_string(),
5208 session_id: None,
5209 };
5210 let subagent_tool_use = LanguageModelToolUse {
5211 id: "subagent_1".into(),
5212 name: SpawnAgentTool::NAME.into(),
5213 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5214 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5215 is_input_complete: true,
5216 thought_signature: None,
5217 };
5218 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5219 subagent_tool_use,
5220 ));
5221 model.end_last_completion_stream();
5222
5223 cx.run_until_parked();
5224
5225 // Verify subagent is running
5226 thread.read_with(cx, |thread, cx| {
5227 assert!(
5228 !thread.running_subagent_ids(cx).is_empty(),
5229 "subagent should be running"
5230 );
5231 });
5232
5233 // The subagent's model returns a non-retryable error
5234 model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
5235 tokens: None,
5236 });
5237
5238 cx.run_until_parked();
5239
5240 // The subagent should no longer be running
5241 thread.read_with(cx, |thread, cx| {
5242 assert!(
5243 thread.running_subagent_ids(cx).is_empty(),
5244 "subagent should not be running after error"
5245 );
5246 });
5247
5248 // The parent model should get a new completion request to respond to the tool error
5249 model.send_last_completion_stream_text_chunk("Response after error");
5250 model.end_last_completion_stream();
5251
5252 send.await.unwrap();
5253
5254 // Verify the parent thread shows the error in the tool call
5255 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5256 assert!(
5257 markdown.contains("Status: Failed"),
5258 "tool call should have Failed status after model error, got:\n{markdown}"
5259 );
5260}
5261
5262#[gpui::test]
5263async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5264 init_test(cx);
5265
5266 let fs = FakeFs::new(cx.executor());
5267 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5268 .await;
5269 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5270
5271 cx.update(|cx| {
5272 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5273 settings.tool_permissions.tools.insert(
5274 EditFileTool::NAME.into(),
5275 agent_settings::ToolRules {
5276 default: Some(settings::ToolPermissionMode::Allow),
5277 always_allow: vec![],
5278 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5279 always_confirm: vec![],
5280 invalid_patterns: vec![],
5281 },
5282 );
5283 agent_settings::AgentSettings::override_global(settings, cx);
5284 });
5285
5286 let context_server_registry =
5287 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5288 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5289 let templates = crate::Templates::new();
5290 let thread = cx.new(|cx| {
5291 crate::Thread::new(
5292 project.clone(),
5293 cx.new(|_cx| prompt_store::ProjectContext::default()),
5294 context_server_registry,
5295 templates.clone(),
5296 None,
5297 cx,
5298 )
5299 });
5300
5301 #[allow(clippy::arc_with_non_send_sync)]
5302 let tool = Arc::new(crate::EditFileTool::new(
5303 project.clone(),
5304 thread.downgrade(),
5305 language_registry,
5306 templates,
5307 ));
5308 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5309
5310 let task = cx.update(|cx| {
5311 tool.run(
5312 crate::EditFileToolInput {
5313 display_description: "Edit sensitive file".to_string(),
5314 path: "root/sensitive_config.txt".into(),
5315 mode: crate::EditFileMode::Edit,
5316 },
5317 event_stream,
5318 cx,
5319 )
5320 });
5321
5322 let result = task.await;
5323 assert!(result.is_err(), "expected edit to be blocked");
5324 assert!(
5325 result.unwrap_err().to_string().contains("blocked"),
5326 "error should mention the edit was blocked"
5327 );
5328}
5329
5330#[gpui::test]
5331async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5332 init_test(cx);
5333
5334 let fs = FakeFs::new(cx.executor());
5335 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5336 .await;
5337 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5338
5339 cx.update(|cx| {
5340 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5341 settings.tool_permissions.tools.insert(
5342 DeletePathTool::NAME.into(),
5343 agent_settings::ToolRules {
5344 default: Some(settings::ToolPermissionMode::Allow),
5345 always_allow: vec![],
5346 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5347 always_confirm: vec![],
5348 invalid_patterns: vec![],
5349 },
5350 );
5351 agent_settings::AgentSettings::override_global(settings, cx);
5352 });
5353
5354 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5355
5356 #[allow(clippy::arc_with_non_send_sync)]
5357 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5358 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5359
5360 let task = cx.update(|cx| {
5361 tool.run(
5362 crate::DeletePathToolInput {
5363 path: "root/important_data.txt".to_string(),
5364 },
5365 event_stream,
5366 cx,
5367 )
5368 });
5369
5370 let result = task.await;
5371 assert!(result.is_err(), "expected deletion to be blocked");
5372 assert!(
5373 result.unwrap_err().contains("blocked"),
5374 "error should mention the deletion was blocked"
5375 );
5376}
5377
5378#[gpui::test]
5379async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5380 init_test(cx);
5381
5382 let fs = FakeFs::new(cx.executor());
5383 fs.insert_tree(
5384 "/root",
5385 json!({
5386 "safe.txt": "content",
5387 "protected": {}
5388 }),
5389 )
5390 .await;
5391 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5392
5393 cx.update(|cx| {
5394 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5395 settings.tool_permissions.tools.insert(
5396 MovePathTool::NAME.into(),
5397 agent_settings::ToolRules {
5398 default: Some(settings::ToolPermissionMode::Allow),
5399 always_allow: vec![],
5400 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5401 always_confirm: vec![],
5402 invalid_patterns: vec![],
5403 },
5404 );
5405 agent_settings::AgentSettings::override_global(settings, cx);
5406 });
5407
5408 #[allow(clippy::arc_with_non_send_sync)]
5409 let tool = Arc::new(crate::MovePathTool::new(project));
5410 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5411
5412 let task = cx.update(|cx| {
5413 tool.run(
5414 crate::MovePathToolInput {
5415 source_path: "root/safe.txt".to_string(),
5416 destination_path: "root/protected/safe.txt".to_string(),
5417 },
5418 event_stream,
5419 cx,
5420 )
5421 });
5422
5423 let result = task.await;
5424 assert!(
5425 result.is_err(),
5426 "expected move to be blocked due to destination path"
5427 );
5428 assert!(
5429 result.unwrap_err().contains("blocked"),
5430 "error should mention the move was blocked"
5431 );
5432}
5433
5434#[gpui::test]
5435async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5436 init_test(cx);
5437
5438 let fs = FakeFs::new(cx.executor());
5439 fs.insert_tree(
5440 "/root",
5441 json!({
5442 "secret.txt": "secret content",
5443 "public": {}
5444 }),
5445 )
5446 .await;
5447 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5448
5449 cx.update(|cx| {
5450 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5451 settings.tool_permissions.tools.insert(
5452 MovePathTool::NAME.into(),
5453 agent_settings::ToolRules {
5454 default: Some(settings::ToolPermissionMode::Allow),
5455 always_allow: vec![],
5456 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5457 always_confirm: vec![],
5458 invalid_patterns: vec![],
5459 },
5460 );
5461 agent_settings::AgentSettings::override_global(settings, cx);
5462 });
5463
5464 #[allow(clippy::arc_with_non_send_sync)]
5465 let tool = Arc::new(crate::MovePathTool::new(project));
5466 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5467
5468 let task = cx.update(|cx| {
5469 tool.run(
5470 crate::MovePathToolInput {
5471 source_path: "root/secret.txt".to_string(),
5472 destination_path: "root/public/not_secret.txt".to_string(),
5473 },
5474 event_stream,
5475 cx,
5476 )
5477 });
5478
5479 let result = task.await;
5480 assert!(
5481 result.is_err(),
5482 "expected move to be blocked due to source path"
5483 );
5484 assert!(
5485 result.unwrap_err().contains("blocked"),
5486 "error should mention the move was blocked"
5487 );
5488}
5489
5490#[gpui::test]
5491async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5492 init_test(cx);
5493
5494 let fs = FakeFs::new(cx.executor());
5495 fs.insert_tree(
5496 "/root",
5497 json!({
5498 "confidential.txt": "confidential data",
5499 "dest": {}
5500 }),
5501 )
5502 .await;
5503 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5504
5505 cx.update(|cx| {
5506 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5507 settings.tool_permissions.tools.insert(
5508 CopyPathTool::NAME.into(),
5509 agent_settings::ToolRules {
5510 default: Some(settings::ToolPermissionMode::Allow),
5511 always_allow: vec![],
5512 always_deny: vec![
5513 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5514 ],
5515 always_confirm: vec![],
5516 invalid_patterns: vec![],
5517 },
5518 );
5519 agent_settings::AgentSettings::override_global(settings, cx);
5520 });
5521
5522 #[allow(clippy::arc_with_non_send_sync)]
5523 let tool = Arc::new(crate::CopyPathTool::new(project));
5524 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5525
5526 let task = cx.update(|cx| {
5527 tool.run(
5528 crate::CopyPathToolInput {
5529 source_path: "root/confidential.txt".to_string(),
5530 destination_path: "root/dest/copy.txt".to_string(),
5531 },
5532 event_stream,
5533 cx,
5534 )
5535 });
5536
5537 let result = task.await;
5538 assert!(result.is_err(), "expected copy to be blocked");
5539 assert!(
5540 result.unwrap_err().contains("blocked"),
5541 "error should mention the copy was blocked"
5542 );
5543}
5544
5545#[gpui::test]
5546async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5547 init_test(cx);
5548
5549 let fs = FakeFs::new(cx.executor());
5550 fs.insert_tree(
5551 "/root",
5552 json!({
5553 "normal.txt": "normal content",
5554 "readonly": {
5555 "config.txt": "readonly content"
5556 }
5557 }),
5558 )
5559 .await;
5560 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5561
5562 cx.update(|cx| {
5563 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5564 settings.tool_permissions.tools.insert(
5565 SaveFileTool::NAME.into(),
5566 agent_settings::ToolRules {
5567 default: Some(settings::ToolPermissionMode::Allow),
5568 always_allow: vec![],
5569 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5570 always_confirm: vec![],
5571 invalid_patterns: vec![],
5572 },
5573 );
5574 agent_settings::AgentSettings::override_global(settings, cx);
5575 });
5576
5577 #[allow(clippy::arc_with_non_send_sync)]
5578 let tool = Arc::new(crate::SaveFileTool::new(project));
5579 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5580
5581 let task = cx.update(|cx| {
5582 tool.run(
5583 crate::SaveFileToolInput {
5584 paths: vec![
5585 std::path::PathBuf::from("root/normal.txt"),
5586 std::path::PathBuf::from("root/readonly/config.txt"),
5587 ],
5588 },
5589 event_stream,
5590 cx,
5591 )
5592 });
5593
5594 let result = task.await;
5595 assert!(
5596 result.is_err(),
5597 "expected save to be blocked due to denied path"
5598 );
5599 assert!(
5600 result.unwrap_err().contains("blocked"),
5601 "error should mention the save was blocked"
5602 );
5603}
5604
5605#[gpui::test]
5606async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5607 init_test(cx);
5608
5609 let fs = FakeFs::new(cx.executor());
5610 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5611 .await;
5612 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5613
5614 cx.update(|cx| {
5615 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5616 settings.tool_permissions.tools.insert(
5617 SaveFileTool::NAME.into(),
5618 agent_settings::ToolRules {
5619 default: Some(settings::ToolPermissionMode::Allow),
5620 always_allow: vec![],
5621 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5622 always_confirm: vec![],
5623 invalid_patterns: vec![],
5624 },
5625 );
5626 agent_settings::AgentSettings::override_global(settings, cx);
5627 });
5628
5629 #[allow(clippy::arc_with_non_send_sync)]
5630 let tool = Arc::new(crate::SaveFileTool::new(project));
5631 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5632
5633 let task = cx.update(|cx| {
5634 tool.run(
5635 crate::SaveFileToolInput {
5636 paths: vec![std::path::PathBuf::from("root/config.secret")],
5637 },
5638 event_stream,
5639 cx,
5640 )
5641 });
5642
5643 let result = task.await;
5644 assert!(result.is_err(), "expected save to be blocked");
5645 assert!(
5646 result.unwrap_err().contains("blocked"),
5647 "error should mention the save was blocked"
5648 );
5649}
5650
5651#[gpui::test]
5652async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5653 init_test(cx);
5654
5655 cx.update(|cx| {
5656 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5657 settings.tool_permissions.tools.insert(
5658 WebSearchTool::NAME.into(),
5659 agent_settings::ToolRules {
5660 default: Some(settings::ToolPermissionMode::Allow),
5661 always_allow: vec![],
5662 always_deny: vec![
5663 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5664 ],
5665 always_confirm: vec![],
5666 invalid_patterns: vec![],
5667 },
5668 );
5669 agent_settings::AgentSettings::override_global(settings, cx);
5670 });
5671
5672 #[allow(clippy::arc_with_non_send_sync)]
5673 let tool = Arc::new(crate::WebSearchTool);
5674 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5675
5676 let input: crate::WebSearchToolInput =
5677 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5678
5679 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5680
5681 let result = task.await;
5682 assert!(result.is_err(), "expected search to be blocked");
5683 match result.unwrap_err() {
5684 crate::WebSearchToolOutput::Error { error } => {
5685 assert!(
5686 error.contains("blocked"),
5687 "error should mention the search was blocked"
5688 );
5689 }
5690 other => panic!("expected Error variant, got: {other:?}"),
5691 }
5692}
5693
5694#[gpui::test]
5695async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5696 init_test(cx);
5697
5698 let fs = FakeFs::new(cx.executor());
5699 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5700 .await;
5701 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5702
5703 cx.update(|cx| {
5704 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5705 settings.tool_permissions.tools.insert(
5706 EditFileTool::NAME.into(),
5707 agent_settings::ToolRules {
5708 default: Some(settings::ToolPermissionMode::Confirm),
5709 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5710 always_deny: vec![],
5711 always_confirm: vec![],
5712 invalid_patterns: vec![],
5713 },
5714 );
5715 agent_settings::AgentSettings::override_global(settings, cx);
5716 });
5717
5718 let context_server_registry =
5719 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5720 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5721 let templates = crate::Templates::new();
5722 let thread = cx.new(|cx| {
5723 crate::Thread::new(
5724 project.clone(),
5725 cx.new(|_cx| prompt_store::ProjectContext::default()),
5726 context_server_registry,
5727 templates.clone(),
5728 None,
5729 cx,
5730 )
5731 });
5732
5733 #[allow(clippy::arc_with_non_send_sync)]
5734 let tool = Arc::new(crate::EditFileTool::new(
5735 project,
5736 thread.downgrade(),
5737 language_registry,
5738 templates,
5739 ));
5740 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5741
5742 let _task = cx.update(|cx| {
5743 tool.run(
5744 crate::EditFileToolInput {
5745 display_description: "Edit README".to_string(),
5746 path: "root/README.md".into(),
5747 mode: crate::EditFileMode::Edit,
5748 },
5749 event_stream,
5750 cx,
5751 )
5752 });
5753
5754 cx.run_until_parked();
5755
5756 let event = rx.try_next();
5757 assert!(
5758 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5759 "expected no authorization request for allowed .md file"
5760 );
5761}
5762
5763#[gpui::test]
5764async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5765 init_test(cx);
5766
5767 let fs = FakeFs::new(cx.executor());
5768 fs.insert_tree(
5769 "/root",
5770 json!({
5771 ".zed": {
5772 "settings.json": "{}"
5773 },
5774 "README.md": "# Hello"
5775 }),
5776 )
5777 .await;
5778 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5779
5780 cx.update(|cx| {
5781 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5782 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5783 agent_settings::AgentSettings::override_global(settings, cx);
5784 });
5785
5786 let context_server_registry =
5787 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5788 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5789 let templates = crate::Templates::new();
5790 let thread = cx.new(|cx| {
5791 crate::Thread::new(
5792 project.clone(),
5793 cx.new(|_cx| prompt_store::ProjectContext::default()),
5794 context_server_registry,
5795 templates.clone(),
5796 None,
5797 cx,
5798 )
5799 });
5800
5801 #[allow(clippy::arc_with_non_send_sync)]
5802 let tool = Arc::new(crate::EditFileTool::new(
5803 project,
5804 thread.downgrade(),
5805 language_registry,
5806 templates,
5807 ));
5808
5809 // Editing a file inside .zed/ should still prompt even with global default: allow,
5810 // because local settings paths are sensitive and require confirmation regardless.
5811 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5812 let _task = cx.update(|cx| {
5813 tool.run(
5814 crate::EditFileToolInput {
5815 display_description: "Edit local settings".to_string(),
5816 path: "root/.zed/settings.json".into(),
5817 mode: crate::EditFileMode::Edit,
5818 },
5819 event_stream,
5820 cx,
5821 )
5822 });
5823
5824 let _update = rx.expect_update_fields().await;
5825 let _auth = rx.expect_authorization().await;
5826}
5827
5828#[gpui::test]
5829async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5830 init_test(cx);
5831
5832 cx.update(|cx| {
5833 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5834 settings.tool_permissions.tools.insert(
5835 FetchTool::NAME.into(),
5836 agent_settings::ToolRules {
5837 default: Some(settings::ToolPermissionMode::Allow),
5838 always_allow: vec![],
5839 always_deny: vec![
5840 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5841 ],
5842 always_confirm: vec![],
5843 invalid_patterns: vec![],
5844 },
5845 );
5846 agent_settings::AgentSettings::override_global(settings, cx);
5847 });
5848
5849 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5850
5851 #[allow(clippy::arc_with_non_send_sync)]
5852 let tool = Arc::new(crate::FetchTool::new(http_client));
5853 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5854
5855 let input: crate::FetchToolInput =
5856 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5857
5858 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5859
5860 let result = task.await;
5861 assert!(result.is_err(), "expected fetch to be blocked");
5862 assert!(
5863 result.unwrap_err().contains("blocked"),
5864 "error should mention the fetch was blocked"
5865 );
5866}
5867
5868#[gpui::test]
5869async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5870 init_test(cx);
5871
5872 cx.update(|cx| {
5873 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5874 settings.tool_permissions.tools.insert(
5875 FetchTool::NAME.into(),
5876 agent_settings::ToolRules {
5877 default: Some(settings::ToolPermissionMode::Confirm),
5878 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5879 always_deny: vec![],
5880 always_confirm: vec![],
5881 invalid_patterns: vec![],
5882 },
5883 );
5884 agent_settings::AgentSettings::override_global(settings, cx);
5885 });
5886
5887 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5888
5889 #[allow(clippy::arc_with_non_send_sync)]
5890 let tool = Arc::new(crate::FetchTool::new(http_client));
5891 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5892
5893 let input: crate::FetchToolInput =
5894 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5895
5896 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5897
5898 cx.run_until_parked();
5899
5900 let event = rx.try_next();
5901 assert!(
5902 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5903 "expected no authorization request for allowed docs.rs URL"
5904 );
5905}
5906
5907#[gpui::test]
5908async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5909 init_test(cx);
5910 always_allow_tools(cx);
5911
5912 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5913 let fake_model = model.as_fake();
5914
5915 // Add a tool so we can simulate tool calls
5916 thread.update(cx, |thread, _cx| {
5917 thread.add_tool(EchoTool);
5918 });
5919
5920 // Start a turn by sending a message
5921 let mut events = thread
5922 .update(cx, |thread, cx| {
5923 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5924 })
5925 .unwrap();
5926 cx.run_until_parked();
5927
5928 // Simulate the model making a tool call
5929 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5930 LanguageModelToolUse {
5931 id: "tool_1".into(),
5932 name: "echo".into(),
5933 raw_input: r#"{"text": "hello"}"#.into(),
5934 input: json!({"text": "hello"}),
5935 is_input_complete: true,
5936 thought_signature: None,
5937 },
5938 ));
5939 fake_model
5940 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5941
5942 // Signal that a message is queued before ending the stream
5943 thread.update(cx, |thread, _cx| {
5944 thread.set_has_queued_message(true);
5945 });
5946
5947 // Now end the stream - tool will run, and the boundary check should see the queue
5948 fake_model.end_last_completion_stream();
5949
5950 // Collect all events until the turn stops
5951 let all_events = collect_events_until_stop(&mut events, cx).await;
5952
5953 // Verify we received the tool call event
5954 let tool_call_ids: Vec<_> = all_events
5955 .iter()
5956 .filter_map(|e| match e {
5957 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5958 _ => None,
5959 })
5960 .collect();
5961 assert_eq!(
5962 tool_call_ids,
5963 vec!["tool_1"],
5964 "Should have received a tool call event for our echo tool"
5965 );
5966
5967 // The turn should have stopped with EndTurn
5968 let stop_reasons = stop_events(all_events);
5969 assert_eq!(
5970 stop_reasons,
5971 vec![acp::StopReason::EndTurn],
5972 "Turn should have ended after tool completion due to queued message"
5973 );
5974
5975 // Verify the queued message flag is still set
5976 thread.update(cx, |thread, _cx| {
5977 assert!(
5978 thread.has_queued_message(),
5979 "Should still have queued message flag set"
5980 );
5981 });
5982
5983 // Thread should be idle now
5984 thread.update(cx, |thread, _cx| {
5985 assert!(
5986 thread.is_turn_complete(),
5987 "Thread should not be running after turn ends"
5988 );
5989 });
5990}