1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
/// In-memory stand-in for a terminal, used to drive the terminal tool in tests
/// without spawning a real process.
struct FakeTerminalHandle {
    // Set to `true` by `kill`; read back through `was_killed`.
    killed: Arc<AtomicBool>,
    // Toggled via `set_stopped_by_user`; reported by `was_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    // One-shot sender consumed by `signal_exit`; `None` once it has been used.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    // Shared task handed out by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    // Fixed terminal id returned by `id`.
    id: acp::TerminalId,
}
76
77impl FakeTerminalHandle {
78 fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
136impl crate::TerminalHandle for FakeTerminalHandle {
137 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
138 Ok(self.id.clone())
139 }
140
141 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
142 Ok(self.output.clone())
143 }
144
145 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
146 Ok(self.wait_for_exit.clone())
147 }
148
149 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
150 self.killed.store(true, Ordering::SeqCst);
151 self.signal_exit();
152 Ok(())
153 }
154
155 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
156 Ok(self.stopped_by_user.load(Ordering::SeqCst))
157 }
158}
159
/// Test double for a subagent: reports a fixed session id and resolves every
/// `send` with the output of one shared task.
struct FakeSubagentHandle {
    // Session id returned by `id`.
    session_id: acp::SessionId,
    // Shared task whose output is returned from `send`.
    send_task: Shared<Task<String>>,
}
164
165impl SubagentHandle for FakeSubagentHandle {
166 fn id(&self) -> acp::SessionId {
167 self.session_id.clone()
168 }
169
170 fn num_entries(&self, _cx: &App) -> usize {
171 unimplemented!()
172 }
173
174 fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
175 let task = self.send_task.clone();
176 cx.background_spawn(async move { Ok(task.await) })
177 }
178}
179
/// Thread environment for tests that hands out pre-built fake handles.
/// Both fields default to `None`.
#[derive(Default)]
struct FakeThreadEnvironment {
    // Handle returned by `create_terminal`; that call panics when this is `None`.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    // Handle returned by `create_subagent`; that call panics when this is `None`.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
185
186impl FakeThreadEnvironment {
187 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
188 Self {
189 terminal_handle: Some(terminal_handle.into()),
190 ..self
191 }
192 }
193}
194
195impl crate::ThreadEnvironment for FakeThreadEnvironment {
196 fn create_terminal(
197 &self,
198 _command: String,
199 _cwd: Option<std::path::PathBuf>,
200 _output_byte_limit: Option<u64>,
201 _cx: &mut AsyncApp,
202 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
203 let handle = self
204 .terminal_handle
205 .clone()
206 .expect("Terminal handle not available on FakeThreadEnvironment");
207 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
208 }
209
210 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
211 Ok(self
212 .subagent_handle
213 .clone()
214 .expect("Subagent handle not available on FakeThreadEnvironment")
215 as Rc<dyn SubagentHandle>)
216 }
217}
218
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    // Every handle created so far, in creation order; appended to by `create_terminal`.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
223
224impl MultiTerminalEnvironment {
225 fn new() -> Self {
226 Self {
227 handles: std::cell::RefCell::new(Vec::new()),
228 }
229 }
230
231 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
232 self.handles.borrow().clone()
233 }
234}
235
236impl crate::ThreadEnvironment for MultiTerminalEnvironment {
237 fn create_terminal(
238 &self,
239 _command: String,
240 _cwd: Option<std::path::PathBuf>,
241 _output_byte_limit: Option<u64>,
242 cx: &mut AsyncApp,
243 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
244 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
245 self.handles.borrow_mut().push(handle.clone());
246 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
247 }
248
249 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
250 unimplemented!()
251 }
252}
253
254fn always_allow_tools(cx: &mut TestAppContext) {
255 cx.update(|cx| {
256 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
257 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
258 agent_settings::AgentSettings::override_global(settings, cx);
259 });
260}
261
262#[gpui::test]
263async fn test_echo(cx: &mut TestAppContext) {
264 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
265 let fake_model = model.as_fake();
266
267 let events = thread
268 .update(cx, |thread, cx| {
269 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
270 })
271 .unwrap();
272 cx.run_until_parked();
273 fake_model.send_last_completion_stream_text_chunk("Hello");
274 fake_model
275 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
276 fake_model.end_last_completion_stream();
277
278 let events = events.collect().await;
279 thread.update(cx, |thread, _cx| {
280 assert_eq!(
281 thread.last_received_or_pending_message().unwrap().role(),
282 Role::Assistant
283 );
284 assert_eq!(
285 thread
286 .last_received_or_pending_message()
287 .unwrap()
288 .to_markdown(),
289 "Hello\n"
290 )
291 });
292 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
293}
294
/// Runs the terminal tool with a 5 ms timeout against a terminal that never
/// exits on its own, and checks that the handle is killed and that the tool
/// result still carries the terminal's output.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference to the handle so it can be inspected after the tool runs.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            }),
            event_stream,
            cx,
        )
    });

    // The first update must already reference the terminal.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fuse and pin the task so it can be polled repeatedly without blocking.
    let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());

    // Wall-clock guard so a regression fails the test instead of hanging it.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        // Let pending work run, then wait on a 1 ms executor timer before polling again.
        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
360
361#[gpui::test]
362#[ignore]
363async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
364 init_test(cx);
365 always_allow_tools(cx);
366
367 let fs = FakeFs::new(cx.executor());
368 let project = Project::test(fs, [], cx).await;
369
370 let environment = Rc::new(cx.update(|cx| {
371 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
372 }));
373 let handle = environment.terminal_handle.clone().unwrap();
374
375 #[allow(clippy::arc_with_non_send_sync)]
376 let tool = Arc::new(crate::TerminalTool::new(project, environment));
377 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
378
379 let _task = cx.update(|cx| {
380 tool.run(
381 ToolInput::resolved(crate::TerminalToolInput {
382 command: "sleep 1000".to_string(),
383 cd: ".".to_string(),
384 timeout_ms: None,
385 }),
386 event_stream,
387 cx,
388 )
389 });
390
391 let update = rx.expect_update_fields().await;
392 assert!(
393 update.content.iter().any(|blocks| {
394 blocks
395 .iter()
396 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
397 }),
398 "expected tool call update to include terminal content"
399 );
400
401 cx.background_executor
402 .timer(Duration::from_millis(25))
403 .await;
404
405 assert!(
406 !handle.was_killed(),
407 "did not expect terminal handle to be killed without a timeout"
408 );
409}
410
411#[gpui::test]
412async fn test_thinking(cx: &mut TestAppContext) {
413 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
414 let fake_model = model.as_fake();
415
416 let events = thread
417 .update(cx, |thread, cx| {
418 thread.send(
419 UserMessageId::new(),
420 [indoc! {"
421 Testing:
422
423 Generate a thinking step where you just think the word 'Think',
424 and have your final answer be 'Hello'
425 "}],
426 cx,
427 )
428 })
429 .unwrap();
430 cx.run_until_parked();
431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
432 text: "Think".to_string(),
433 signature: None,
434 });
435 fake_model.send_last_completion_stream_text_chunk("Hello");
436 fake_model
437 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
438 fake_model.end_last_completion_stream();
439
440 let events = events.collect().await;
441 thread.update(cx, |thread, _cx| {
442 assert_eq!(
443 thread.last_received_or_pending_message().unwrap().role(),
444 Role::Assistant
445 );
446 assert_eq!(
447 thread
448 .last_received_or_pending_message()
449 .unwrap()
450 .to_markdown(),
451 indoc! {"
452 <think>Think</think>
453 Hello
454 "}
455 )
456 });
457 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
458}
459
460#[gpui::test]
461async fn test_system_prompt(cx: &mut TestAppContext) {
462 let ThreadTest {
463 model,
464 thread,
465 project_context,
466 ..
467 } = setup(cx, TestModel::Fake).await;
468 let fake_model = model.as_fake();
469
470 project_context.update(cx, |project_context, _cx| {
471 project_context.shell = "test-shell".into()
472 });
473 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
474 thread
475 .update(cx, |thread, cx| {
476 thread.send(UserMessageId::new(), ["abc"], cx)
477 })
478 .unwrap();
479 cx.run_until_parked();
480 let mut pending_completions = fake_model.pending_completions();
481 assert_eq!(
482 pending_completions.len(),
483 1,
484 "unexpected pending completions: {:?}",
485 pending_completions
486 );
487
488 let pending_completion = pending_completions.pop().unwrap();
489 assert_eq!(pending_completion.messages[0].role, Role::System);
490
491 let system_message = &pending_completion.messages[0];
492 let system_prompt = system_message.content[0].to_str().unwrap();
493 assert!(
494 system_prompt.contains("test-shell"),
495 "unexpected system message: {:?}",
496 system_message
497 );
498 assert!(
499 system_prompt.contains("## Fixing Diagnostics"),
500 "unexpected system message: {:?}",
501 system_message
502 );
503}
504
505#[gpui::test]
506async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
507 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
508 let fake_model = model.as_fake();
509
510 thread
511 .update(cx, |thread, cx| {
512 thread.send(UserMessageId::new(), ["abc"], cx)
513 })
514 .unwrap();
515 cx.run_until_parked();
516 let mut pending_completions = fake_model.pending_completions();
517 assert_eq!(
518 pending_completions.len(),
519 1,
520 "unexpected pending completions: {:?}",
521 pending_completions
522 );
523
524 let pending_completion = pending_completions.pop().unwrap();
525 assert_eq!(pending_completion.messages[0].role, Role::System);
526
527 let system_message = &pending_completion.messages[0];
528 let system_prompt = system_message.content[0].to_str().unwrap();
529 assert!(
530 !system_prompt.contains("## Tool Use"),
531 "unexpected system message: {:?}",
532 system_message
533 );
534 assert!(
535 !system_prompt.contains("## Fixing Diagnostics"),
536 "unexpected system message: {:?}",
537 system_message
538 );
539}
540
541#[gpui::test]
542async fn test_prompt_caching(cx: &mut TestAppContext) {
543 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
544 let fake_model = model.as_fake();
545
546 // Send initial user message and verify it's cached
547 thread
548 .update(cx, |thread, cx| {
549 thread.send(UserMessageId::new(), ["Message 1"], cx)
550 })
551 .unwrap();
552 cx.run_until_parked();
553
554 let completion = fake_model.pending_completions().pop().unwrap();
555 assert_eq!(
556 completion.messages[1..],
557 vec![LanguageModelRequestMessage {
558 role: Role::User,
559 content: vec!["Message 1".into()],
560 cache: true,
561 reasoning_details: None,
562 }]
563 );
564 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
565 "Response to Message 1".into(),
566 ));
567 fake_model.end_last_completion_stream();
568 cx.run_until_parked();
569
570 // Send another user message and verify only the latest is cached
571 thread
572 .update(cx, |thread, cx| {
573 thread.send(UserMessageId::new(), ["Message 2"], cx)
574 })
575 .unwrap();
576 cx.run_until_parked();
577
578 let completion = fake_model.pending_completions().pop().unwrap();
579 assert_eq!(
580 completion.messages[1..],
581 vec![
582 LanguageModelRequestMessage {
583 role: Role::User,
584 content: vec!["Message 1".into()],
585 cache: false,
586 reasoning_details: None,
587 },
588 LanguageModelRequestMessage {
589 role: Role::Assistant,
590 content: vec!["Response to Message 1".into()],
591 cache: false,
592 reasoning_details: None,
593 },
594 LanguageModelRequestMessage {
595 role: Role::User,
596 content: vec!["Message 2".into()],
597 cache: true,
598 reasoning_details: None,
599 }
600 ]
601 );
602 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
603 "Response to Message 2".into(),
604 ));
605 fake_model.end_last_completion_stream();
606 cx.run_until_parked();
607
608 // Simulate a tool call and verify that the latest tool result is cached
609 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
610 thread
611 .update(cx, |thread, cx| {
612 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
613 })
614 .unwrap();
615 cx.run_until_parked();
616
617 let tool_use = LanguageModelToolUse {
618 id: "tool_1".into(),
619 name: EchoTool::NAME.into(),
620 raw_input: json!({"text": "test"}).to_string(),
621 input: json!({"text": "test"}),
622 is_input_complete: true,
623 thought_signature: None,
624 };
625 fake_model
626 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
627 fake_model.end_last_completion_stream();
628 cx.run_until_parked();
629
630 let completion = fake_model.pending_completions().pop().unwrap();
631 let tool_result = LanguageModelToolResult {
632 tool_use_id: "tool_1".into(),
633 tool_name: EchoTool::NAME.into(),
634 is_error: false,
635 content: "test".into(),
636 output: Some("test".into()),
637 };
638 assert_eq!(
639 completion.messages[1..],
640 vec![
641 LanguageModelRequestMessage {
642 role: Role::User,
643 content: vec!["Message 1".into()],
644 cache: false,
645 reasoning_details: None,
646 },
647 LanguageModelRequestMessage {
648 role: Role::Assistant,
649 content: vec!["Response to Message 1".into()],
650 cache: false,
651 reasoning_details: None,
652 },
653 LanguageModelRequestMessage {
654 role: Role::User,
655 content: vec!["Message 2".into()],
656 cache: false,
657 reasoning_details: None,
658 },
659 LanguageModelRequestMessage {
660 role: Role::Assistant,
661 content: vec!["Response to Message 2".into()],
662 cache: false,
663 reasoning_details: None,
664 },
665 LanguageModelRequestMessage {
666 role: Role::User,
667 content: vec!["Use the echo tool".into()],
668 cache: false,
669 reasoning_details: None,
670 },
671 LanguageModelRequestMessage {
672 role: Role::Assistant,
673 content: vec![MessageContent::ToolUse(tool_use)],
674 cache: false,
675 reasoning_details: None,
676 },
677 LanguageModelRequestMessage {
678 role: Role::User,
679 content: vec![MessageContent::ToolResult(tool_result)],
680 cache: true,
681 reasoning_details: None,
682 }
683 ]
684 );
685}
686
687#[gpui::test]
688#[cfg_attr(not(feature = "e2e"), ignore)]
689async fn test_basic_tool_calls(cx: &mut TestAppContext) {
690 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
691
692 // Test a tool call that's likely to complete *before* streaming stops.
693 let events = thread
694 .update(cx, |thread, cx| {
695 thread.add_tool(EchoTool);
696 thread.send(
697 UserMessageId::new(),
698 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
699 cx,
700 )
701 })
702 .unwrap()
703 .collect()
704 .await;
705 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
706
707 // Test a tool calls that's likely to complete *after* streaming stops.
708 let events = thread
709 .update(cx, |thread, cx| {
710 thread.remove_tool(&EchoTool::NAME);
711 thread.add_tool(DelayTool);
712 thread.send(
713 UserMessageId::new(),
714 [
715 "Now call the delay tool with 200ms.",
716 "When the timer goes off, then you echo the output of the tool.",
717 ],
718 cx,
719 )
720 })
721 .unwrap()
722 .collect()
723 .await;
724 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
725 thread.update(cx, |thread, _cx| {
726 assert!(
727 thread
728 .last_received_or_pending_message()
729 .unwrap()
730 .as_agent_message()
731 .unwrap()
732 .content
733 .iter()
734 .any(|content| {
735 if let AgentMessageContent::Text(text) = content {
736 text.contains("Ding")
737 } else {
738 false
739 }
740 }),
741 "{}",
742 thread.to_markdown()
743 );
744 });
745}
746
747#[gpui::test]
748#[cfg_attr(not(feature = "e2e"), ignore)]
749async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
750 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
751
752 // Test a tool call that's likely to complete *before* streaming stops.
753 let mut events = thread
754 .update(cx, |thread, cx| {
755 thread.add_tool(WordListTool);
756 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
757 })
758 .unwrap();
759
760 let mut saw_partial_tool_use = false;
761 while let Some(event) = events.next().await {
762 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
763 thread.update(cx, |thread, _cx| {
764 // Look for a tool use in the thread's last message
765 let message = thread.last_received_or_pending_message().unwrap();
766 let agent_message = message.as_agent_message().unwrap();
767 let last_content = agent_message.content.last().unwrap();
768 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
769 assert_eq!(last_tool_use.name.as_ref(), "word_list");
770 if tool_call.status == acp::ToolCallStatus::Pending {
771 if !last_tool_use.is_input_complete
772 && last_tool_use.input.get("g").is_none()
773 {
774 saw_partial_tool_use = true;
775 }
776 } else {
777 last_tool_use
778 .input
779 .get("a")
780 .expect("'a' has streamed because input is now complete");
781 last_tool_use
782 .input
783 .get("g")
784 .expect("'g' has streamed because input is now complete");
785 }
786 } else {
787 panic!("last content should be a tool use");
788 }
789 });
790 }
791 }
792
793 assert!(
794 saw_partial_tool_use,
795 "should see at least one partially streamed tool use in the history"
796 );
797}
798
799#[gpui::test]
800async fn test_tool_authorization(cx: &mut TestAppContext) {
801 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
802 let fake_model = model.as_fake();
803
804 let mut events = thread
805 .update(cx, |thread, cx| {
806 thread.add_tool(ToolRequiringPermission);
807 thread.send(UserMessageId::new(), ["abc"], cx)
808 })
809 .unwrap();
810 cx.run_until_parked();
811 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
812 LanguageModelToolUse {
813 id: "tool_id_1".into(),
814 name: ToolRequiringPermission::NAME.into(),
815 raw_input: "{}".into(),
816 input: json!({}),
817 is_input_complete: true,
818 thought_signature: None,
819 },
820 ));
821 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
822 LanguageModelToolUse {
823 id: "tool_id_2".into(),
824 name: ToolRequiringPermission::NAME.into(),
825 raw_input: "{}".into(),
826 input: json!({}),
827 is_input_complete: true,
828 thought_signature: None,
829 },
830 ));
831 fake_model.end_last_completion_stream();
832 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
833 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
834
835 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
836 tool_call_auth_1
837 .response
838 .send(acp::PermissionOptionId::new("allow"))
839 .unwrap();
840 cx.run_until_parked();
841
842 // Reject the second - send "deny" option_id directly since Deny is now a button
843 tool_call_auth_2
844 .response
845 .send(acp::PermissionOptionId::new("deny"))
846 .unwrap();
847 cx.run_until_parked();
848
849 let completion = fake_model.pending_completions().pop().unwrap();
850 let message = completion.messages.last().unwrap();
851 assert_eq!(
852 message.content,
853 vec![
854 language_model::MessageContent::ToolResult(LanguageModelToolResult {
855 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
856 tool_name: ToolRequiringPermission::NAME.into(),
857 is_error: false,
858 content: "Allowed".into(),
859 output: Some("Allowed".into())
860 }),
861 language_model::MessageContent::ToolResult(LanguageModelToolResult {
862 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
863 tool_name: ToolRequiringPermission::NAME.into(),
864 is_error: true,
865 content: "Permission to run tool denied by user".into(),
866 output: Some("Permission to run tool denied by user".into())
867 })
868 ]
869 );
870
871 // Simulate yet another tool call.
872 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
873 LanguageModelToolUse {
874 id: "tool_id_3".into(),
875 name: ToolRequiringPermission::NAME.into(),
876 raw_input: "{}".into(),
877 input: json!({}),
878 is_input_complete: true,
879 thought_signature: None,
880 },
881 ));
882 fake_model.end_last_completion_stream();
883
884 // Respond by always allowing tools - send transformed option_id
885 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
886 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
887 tool_call_auth_3
888 .response
889 .send(acp::PermissionOptionId::new(
890 "always_allow:tool_requiring_permission",
891 ))
892 .unwrap();
893 cx.run_until_parked();
894 let completion = fake_model.pending_completions().pop().unwrap();
895 let message = completion.messages.last().unwrap();
896 assert_eq!(
897 message.content,
898 vec![language_model::MessageContent::ToolResult(
899 LanguageModelToolResult {
900 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
901 tool_name: ToolRequiringPermission::NAME.into(),
902 is_error: false,
903 content: "Allowed".into(),
904 output: Some("Allowed".into())
905 }
906 )]
907 );
908
909 // Simulate a final tool call, ensuring we don't trigger authorization.
910 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
911 LanguageModelToolUse {
912 id: "tool_id_4".into(),
913 name: ToolRequiringPermission::NAME.into(),
914 raw_input: "{}".into(),
915 input: json!({}),
916 is_input_complete: true,
917 thought_signature: None,
918 },
919 ));
920 fake_model.end_last_completion_stream();
921 cx.run_until_parked();
922 let completion = fake_model.pending_completions().pop().unwrap();
923 let message = completion.messages.last().unwrap();
924 assert_eq!(
925 message.content,
926 vec![language_model::MessageContent::ToolResult(
927 LanguageModelToolResult {
928 tool_use_id: "tool_id_4".into(),
929 tool_name: ToolRequiringPermission::NAME.into(),
930 is_error: false,
931 content: "Allowed".into(),
932 output: Some("Allowed".into())
933 }
934 )]
935 );
936}
937
938#[gpui::test]
939async fn test_tool_hallucination(cx: &mut TestAppContext) {
940 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
941 let fake_model = model.as_fake();
942
943 let mut events = thread
944 .update(cx, |thread, cx| {
945 thread.send(UserMessageId::new(), ["abc"], cx)
946 })
947 .unwrap();
948 cx.run_until_parked();
949 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
950 LanguageModelToolUse {
951 id: "tool_id_1".into(),
952 name: "nonexistent_tool".into(),
953 raw_input: "{}".into(),
954 input: json!({}),
955 is_input_complete: true,
956 thought_signature: None,
957 },
958 ));
959 fake_model.end_last_completion_stream();
960
961 let tool_call = expect_tool_call(&mut events).await;
962 assert_eq!(tool_call.title, "nonexistent_tool");
963 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
964 let update = expect_tool_call_update_fields(&mut events).await;
965 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
966}
967
968async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
969 let event = events
970 .next()
971 .await
972 .expect("no tool call authorization event received")
973 .unwrap();
974 match event {
975 ThreadEvent::ToolCall(tool_call) => tool_call,
976 event => {
977 panic!("Unexpected event {event:?}");
978 }
979 }
980}
981
982async fn expect_tool_call_update_fields(
983 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
984) -> acp::ToolCallUpdate {
985 let event = events
986 .next()
987 .await
988 .expect("no tool call authorization event received")
989 .unwrap();
990 match event {
991 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
992 event => {
993 panic!("Unexpected event {event:?}");
994 }
995 }
996}
997
998async fn next_tool_call_authorization(
999 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1000) -> ToolCallAuthorization {
1001 loop {
1002 let event = events
1003 .next()
1004 .await
1005 .expect("no tool call authorization event received")
1006 .unwrap();
1007 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1008 let permission_kinds = tool_call_authorization
1009 .options
1010 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1011 .map(|option| option.kind);
1012 let allow_once = tool_call_authorization
1013 .options
1014 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1015 .map(|option| option.kind);
1016
1017 assert_eq!(
1018 permission_kinds,
1019 Some(acp::PermissionOptionKind::AllowAlways)
1020 );
1021 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1022 return tool_call_authorization;
1023 }
1024 }
1025}
1026
/// A two-word terminal command yields three choices, including one scoped to
/// the `cargo build` pattern.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1048
/// When the second token is a flag, the pattern choice covers only the command
/// name (`ls`).
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1068
/// A single-word command yields three choices, with the pattern choice named
/// after that word.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1088
/// Checks the allow labels offered when editing `src/main.rs`.
///
/// The test expects a tool-wide label and a label scoped to the `src/`
/// directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_owned()]);

    let entries = match context.build_permission_options() {
        PermissionOptions::Dropdown(entries) => entries,
        _ => panic!("Expected dropdown permission options"),
    };

    let mut allow_labels: Vec<&str> = Vec::new();
    for entry in entries.iter() {
        allow_labels.push(entry.allow.name.as_ref());
    }

    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(allow_labels.contains(&expected));
    }
}
1106
/// Checks the allow labels offered when fetching `https://docs.rs/gpui`.
///
/// The test expects a tool-wide label and a label scoped to the `docs.rs`
/// domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let url = String::from("https://docs.rs/gpui");
    let built = ToolPermissionContext::new(FetchTool::NAME, vec![url]).build_permission_options();

    let entries = match built {
        PermissionOptions::Dropdown(entries) => entries,
        _ => panic!("Expected dropdown permission options"),
    };

    // True when some choice exposes an "allow" option with exactly this label.
    let offers = |label: &str| {
        entries.iter().any(|entry| {
            let name: &str = entry.allow.name.as_ref();
            name == label
        })
    };
    assert!(offers("Always for fetch"));
    assert!(offers("Always for `docs.rs`"));
}
1124
/// Checks that no command-pattern choice is offered for `./deploy.sh --production`.
///
/// The test expects only two choices (tool-wide and one-time), and that no
/// allow label mentions "commands".
#[test]
fn test_permission_options_without_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("./deploy.sh --production")],
    );

    let entries = match context.build_permission_options() {
        PermissionOptions::Dropdown(entries) => entries,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(entries.len(), 2);

    let allow_labels: Vec<&str> = entries
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();

    for expected in ["Always for terminal", "Only this time"] {
        assert!(allow_labels.contains(&expected));
    }
    // Every label must be free of the wording used by pattern-scoped choices.
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1146
/// Checks the options built for a symlink-target authorization.
///
/// The test expects a flat list of exactly two options: an `allow` option of
/// kind `AllowOnce` and a `deny` option of kind `RejectOnce`.
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    let built = ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    )
    .build_permission_options();

    let flat = match built {
        PermissionOptions::Flat(flat) => flat,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };
    assert_eq!(flat.len(), 2);

    // True when an option with this id and this kind is present.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        flat.iter()
            .any(|option| option.kind == kind && option.option_id.0.as_ref() == id)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1167
/// Checks the option ids generated for a terminal command.
///
/// For `cargo build --release` the allow and deny sides must each contain a
/// tool-wide id, a one-time id, and an id starting with the pattern prefix
/// for the terminal tool.
#[test]
fn test_permission_option_ids_for_terminal() {
    let built = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec![String::from("cargo build --release")],
    )
    .build_permission_options();

    let choices = match built {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Walk the choices once, splitting ids into the allow and deny sides.
    let mut allow_ids: Vec<String> = Vec::new();
    let mut deny_ids: Vec<String> = Vec::new();
    for choice in choices.iter() {
        allow_ids.push(choice.allow.option_id.0.to_string());
        deny_ids.push(choice.deny.option_id.0.to_string());
    }

    let has_exact = |ids: &[String], wanted: &str| ids.iter().any(|id| id == wanted);
    let has_prefix = |ids: &[String], prefix: &str| ids.iter().any(|id| id.starts_with(prefix));

    assert!(has_exact(&allow_ids, "always_allow:terminal"));
    assert!(has_exact(&allow_ids, "allow"));
    assert!(
        has_prefix(&allow_ids, "always_allow_pattern:terminal\n"),
        "Missing allow pattern option"
    );

    assert!(has_exact(&deny_ids, "always_deny:terminal"));
    assert!(has_exact(&deny_ids, "deny"));
    assert!(
        has_prefix(&deny_ids, "always_deny_pattern:terminal\n"),
        "Missing deny pattern option"
    );
}
1207
1208#[gpui::test]
1209#[cfg_attr(not(feature = "e2e"), ignore)]
1210async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1211 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1212
1213 // Test concurrent tool calls with different delay times
1214 let events = thread
1215 .update(cx, |thread, cx| {
1216 thread.add_tool(DelayTool);
1217 thread.send(
1218 UserMessageId::new(),
1219 [
1220 "Call the delay tool twice in the same message.",
1221 "Once with 100ms. Once with 300ms.",
1222 "When both timers are complete, describe the outputs.",
1223 ],
1224 cx,
1225 )
1226 })
1227 .unwrap()
1228 .collect()
1229 .await;
1230
1231 let stop_reasons = stop_events(events);
1232 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1233
1234 thread.update(cx, |thread, _cx| {
1235 let last_message = thread.last_received_or_pending_message().unwrap();
1236 let agent_message = last_message.as_agent_message().unwrap();
1237 let text = agent_message
1238 .content
1239 .iter()
1240 .filter_map(|content| {
1241 if let AgentMessageContent::Text(text) = content {
1242 Some(text.as_str())
1243 } else {
1244 None
1245 }
1246 })
1247 .collect::<String>();
1248
1249 assert!(text.contains("Ding"));
1250 });
1251}
1252
1253#[gpui::test]
1254async fn test_profiles(cx: &mut TestAppContext) {
1255 let ThreadTest {
1256 model, thread, fs, ..
1257 } = setup(cx, TestModel::Fake).await;
1258 let fake_model = model.as_fake();
1259
1260 thread.update(cx, |thread, _cx| {
1261 thread.add_tool(DelayTool);
1262 thread.add_tool(EchoTool);
1263 thread.add_tool(InfiniteTool);
1264 });
1265
1266 // Override profiles and wait for settings to be loaded.
1267 fs.insert_file(
1268 paths::settings_file(),
1269 json!({
1270 "agent": {
1271 "profiles": {
1272 "test-1": {
1273 "name": "Test Profile 1",
1274 "tools": {
1275 EchoTool::NAME: true,
1276 DelayTool::NAME: true,
1277 }
1278 },
1279 "test-2": {
1280 "name": "Test Profile 2",
1281 "tools": {
1282 InfiniteTool::NAME: true,
1283 }
1284 }
1285 }
1286 }
1287 })
1288 .to_string()
1289 .into_bytes(),
1290 )
1291 .await;
1292 cx.run_until_parked();
1293
1294 // Test that test-1 profile (default) has echo and delay tools
1295 thread
1296 .update(cx, |thread, cx| {
1297 thread.set_profile(AgentProfileId("test-1".into()), cx);
1298 thread.send(UserMessageId::new(), ["test"], cx)
1299 })
1300 .unwrap();
1301 cx.run_until_parked();
1302
1303 let mut pending_completions = fake_model.pending_completions();
1304 assert_eq!(pending_completions.len(), 1);
1305 let completion = pending_completions.pop().unwrap();
1306 let tool_names: Vec<String> = completion
1307 .tools
1308 .iter()
1309 .map(|tool| tool.name.clone())
1310 .collect();
1311 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1312 fake_model.end_last_completion_stream();
1313
1314 // Switch to test-2 profile, and verify that it has only the infinite tool.
1315 thread
1316 .update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test-2".into()), cx);
1318 thread.send(UserMessageId::new(), ["test2"], cx)
1319 })
1320 .unwrap();
1321 cx.run_until_parked();
1322 let mut pending_completions = fake_model.pending_completions();
1323 assert_eq!(pending_completions.len(), 1);
1324 let completion = pending_completions.pop().unwrap();
1325 let tool_names: Vec<String> = completion
1326 .tools
1327 .iter()
1328 .map(|tool| tool.name.clone())
1329 .collect();
1330 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1331}
1332
1333#[gpui::test]
1334async fn test_mcp_tools(cx: &mut TestAppContext) {
1335 let ThreadTest {
1336 model,
1337 thread,
1338 context_server_store,
1339 fs,
1340 ..
1341 } = setup(cx, TestModel::Fake).await;
1342 let fake_model = model.as_fake();
1343
1344 // Override profiles and wait for settings to be loaded.
1345 fs.insert_file(
1346 paths::settings_file(),
1347 json!({
1348 "agent": {
1349 "tool_permissions": { "default": "allow" },
1350 "profiles": {
1351 "test": {
1352 "name": "Test Profile",
1353 "enable_all_context_servers": true,
1354 "tools": {
1355 EchoTool::NAME: true,
1356 }
1357 },
1358 }
1359 }
1360 })
1361 .to_string()
1362 .into_bytes(),
1363 )
1364 .await;
1365 cx.run_until_parked();
1366 thread.update(cx, |thread, cx| {
1367 thread.set_profile(AgentProfileId("test".into()), cx)
1368 });
1369
1370 let mut mcp_tool_calls = setup_context_server(
1371 "test_server",
1372 vec![context_server::types::Tool {
1373 name: "echo".into(),
1374 description: None,
1375 input_schema: serde_json::to_value(EchoTool::input_schema(
1376 LanguageModelToolSchemaFormat::JsonSchema,
1377 ))
1378 .unwrap(),
1379 output_schema: None,
1380 annotations: None,
1381 }],
1382 &context_server_store,
1383 cx,
1384 );
1385
1386 let events = thread.update(cx, |thread, cx| {
1387 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1388 });
1389 cx.run_until_parked();
1390
1391 // Simulate the model calling the MCP tool.
1392 let completion = fake_model.pending_completions().pop().unwrap();
1393 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1394 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1395 LanguageModelToolUse {
1396 id: "tool_1".into(),
1397 name: "echo".into(),
1398 raw_input: json!({"text": "test"}).to_string(),
1399 input: json!({"text": "test"}),
1400 is_input_complete: true,
1401 thought_signature: None,
1402 },
1403 ));
1404 fake_model.end_last_completion_stream();
1405 cx.run_until_parked();
1406
1407 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1408 assert_eq!(tool_call_params.name, "echo");
1409 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1410 tool_call_response
1411 .send(context_server::types::CallToolResponse {
1412 content: vec![context_server::types::ToolResponseContent::Text {
1413 text: "test".into(),
1414 }],
1415 is_error: None,
1416 meta: None,
1417 structured_content: None,
1418 })
1419 .unwrap();
1420 cx.run_until_parked();
1421
1422 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1423 fake_model.send_last_completion_stream_text_chunk("Done!");
1424 fake_model.end_last_completion_stream();
1425 events.collect::<Vec<_>>().await;
1426
1427 // Send again after adding the echo tool, ensuring the name collision is resolved.
1428 let events = thread.update(cx, |thread, cx| {
1429 thread.add_tool(EchoTool);
1430 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1431 });
1432 cx.run_until_parked();
1433 let completion = fake_model.pending_completions().pop().unwrap();
1434 assert_eq!(
1435 tool_names_for_completion(&completion),
1436 vec!["echo", "test_server_echo"]
1437 );
1438 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1439 LanguageModelToolUse {
1440 id: "tool_2".into(),
1441 name: "test_server_echo".into(),
1442 raw_input: json!({"text": "mcp"}).to_string(),
1443 input: json!({"text": "mcp"}),
1444 is_input_complete: true,
1445 thought_signature: None,
1446 },
1447 ));
1448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1449 LanguageModelToolUse {
1450 id: "tool_3".into(),
1451 name: "echo".into(),
1452 raw_input: json!({"text": "native"}).to_string(),
1453 input: json!({"text": "native"}),
1454 is_input_complete: true,
1455 thought_signature: None,
1456 },
1457 ));
1458 fake_model.end_last_completion_stream();
1459 cx.run_until_parked();
1460
1461 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1462 assert_eq!(tool_call_params.name, "echo");
1463 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1464 tool_call_response
1465 .send(context_server::types::CallToolResponse {
1466 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1467 is_error: None,
1468 meta: None,
1469 structured_content: None,
1470 })
1471 .unwrap();
1472 cx.run_until_parked();
1473
1474 // Ensure the tool results were inserted with the correct names.
1475 let completion = fake_model.pending_completions().pop().unwrap();
1476 assert_eq!(
1477 completion.messages.last().unwrap().content,
1478 vec![
1479 MessageContent::ToolResult(LanguageModelToolResult {
1480 tool_use_id: "tool_3".into(),
1481 tool_name: "echo".into(),
1482 is_error: false,
1483 content: "native".into(),
1484 output: Some("native".into()),
1485 },),
1486 MessageContent::ToolResult(LanguageModelToolResult {
1487 tool_use_id: "tool_2".into(),
1488 tool_name: "test_server_echo".into(),
1489 is_error: false,
1490 content: "mcp".into(),
1491 output: Some("mcp".into()),
1492 },),
1493 ]
1494 );
1495 fake_model.end_last_completion_stream();
1496 events.collect::<Vec<_>>().await;
1497}
1498
1499#[gpui::test]
1500async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1501 let ThreadTest {
1502 model,
1503 thread,
1504 context_server_store,
1505 fs,
1506 ..
1507 } = setup(cx, TestModel::Fake).await;
1508 let fake_model = model.as_fake();
1509
1510 // Setup settings to allow MCP tools
1511 fs.insert_file(
1512 paths::settings_file(),
1513 json!({
1514 "agent": {
1515 "always_allow_tool_actions": true,
1516 "profiles": {
1517 "test": {
1518 "name": "Test Profile",
1519 "enable_all_context_servers": true,
1520 "tools": {}
1521 },
1522 }
1523 }
1524 })
1525 .to_string()
1526 .into_bytes(),
1527 )
1528 .await;
1529 cx.run_until_parked();
1530 thread.update(cx, |thread, cx| {
1531 thread.set_profile(AgentProfileId("test".into()), cx)
1532 });
1533
1534 // Setup a context server with a tool
1535 let mut mcp_tool_calls = setup_context_server(
1536 "github_server",
1537 vec![context_server::types::Tool {
1538 name: "issue_read".into(),
1539 description: Some("Read a GitHub issue".into()),
1540 input_schema: json!({
1541 "type": "object",
1542 "properties": {
1543 "issue_url": { "type": "string" }
1544 }
1545 }),
1546 output_schema: None,
1547 annotations: None,
1548 }],
1549 &context_server_store,
1550 cx,
1551 );
1552
1553 // Send a message and have the model call the MCP tool
1554 let events = thread.update(cx, |thread, cx| {
1555 thread
1556 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1557 .unwrap()
1558 });
1559 cx.run_until_parked();
1560
1561 // Verify the MCP tool is available to the model
1562 let completion = fake_model.pending_completions().pop().unwrap();
1563 assert_eq!(
1564 tool_names_for_completion(&completion),
1565 vec!["issue_read"],
1566 "MCP tool should be available"
1567 );
1568
1569 // Simulate the model calling the MCP tool
1570 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1571 LanguageModelToolUse {
1572 id: "tool_1".into(),
1573 name: "issue_read".into(),
1574 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1575 .to_string(),
1576 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1577 is_input_complete: true,
1578 thought_signature: None,
1579 },
1580 ));
1581 fake_model.end_last_completion_stream();
1582 cx.run_until_parked();
1583
1584 // The MCP server receives the tool call and responds with content
1585 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1586 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1587 assert_eq!(tool_call_params.name, "issue_read");
1588 tool_call_response
1589 .send(context_server::types::CallToolResponse {
1590 content: vec![context_server::types::ToolResponseContent::Text {
1591 text: expected_tool_output.into(),
1592 }],
1593 is_error: None,
1594 meta: None,
1595 structured_content: None,
1596 })
1597 .unwrap();
1598 cx.run_until_parked();
1599
1600 // After tool completes, the model continues with a new completion request
1601 // that includes the tool results. We need to respond to this.
1602 let _completion = fake_model.pending_completions().pop().unwrap();
1603 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1604 fake_model
1605 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1606 fake_model.end_last_completion_stream();
1607 events.collect::<Vec<_>>().await;
1608
1609 // Verify the tool result is stored in the thread by checking the markdown output.
1610 // The tool result is in the first assistant message (not the last one, which is
1611 // the model's response after the tool completed).
1612 thread.update(cx, |thread, _cx| {
1613 let markdown = thread.to_markdown();
1614 assert!(
1615 markdown.contains("**Tool Result**: issue_read"),
1616 "Thread should contain tool result header"
1617 );
1618 assert!(
1619 markdown.contains(expected_tool_output),
1620 "Thread should contain tool output: {}",
1621 expected_tool_output
1622 );
1623 });
1624
1625 // Simulate app restart: disconnect the MCP server.
1626 // After restart, the MCP server won't be connected yet when the thread is replayed.
1627 context_server_store.update(cx, |store, cx| {
1628 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1629 });
1630 cx.run_until_parked();
1631
1632 // Replay the thread (this is what happens when loading a saved thread)
1633 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1634
1635 let mut found_tool_call = None;
1636 let mut found_tool_call_update_with_output = None;
1637
1638 while let Some(event) = replay_events.next().await {
1639 let event = event.unwrap();
1640 match &event {
1641 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1642 found_tool_call = Some(tc.clone());
1643 }
1644 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1645 if update.tool_call_id.to_string() == "tool_1" =>
1646 {
1647 if update.fields.raw_output.is_some() {
1648 found_tool_call_update_with_output = Some(update.clone());
1649 }
1650 }
1651 _ => {}
1652 }
1653 }
1654
1655 // The tool call should be found
1656 assert!(
1657 found_tool_call.is_some(),
1658 "Tool call should be emitted during replay"
1659 );
1660
1661 assert!(
1662 found_tool_call_update_with_output.is_some(),
1663 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1664 );
1665
1666 let update = found_tool_call_update_with_output.unwrap();
1667 assert_eq!(
1668 update.fields.raw_output,
1669 Some(expected_tool_output.into()),
1670 "raw_output should contain the saved tool result"
1671 );
1672
1673 // Also verify the status is correct (completed, not failed)
1674 assert_eq!(
1675 update.fields.status,
1676 Some(acp::ToolCallStatus::Completed),
1677 "Tool call status should reflect the original completion status"
1678 );
1679}
1680
1681#[gpui::test]
1682async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1683 let ThreadTest {
1684 model,
1685 thread,
1686 context_server_store,
1687 fs,
1688 ..
1689 } = setup(cx, TestModel::Fake).await;
1690 let fake_model = model.as_fake();
1691
1692 // Set up a profile with all tools enabled
1693 fs.insert_file(
1694 paths::settings_file(),
1695 json!({
1696 "agent": {
1697 "profiles": {
1698 "test": {
1699 "name": "Test Profile",
1700 "enable_all_context_servers": true,
1701 "tools": {
1702 EchoTool::NAME: true,
1703 DelayTool::NAME: true,
1704 WordListTool::NAME: true,
1705 ToolRequiringPermission::NAME: true,
1706 InfiniteTool::NAME: true,
1707 }
1708 },
1709 }
1710 }
1711 })
1712 .to_string()
1713 .into_bytes(),
1714 )
1715 .await;
1716 cx.run_until_parked();
1717
1718 thread.update(cx, |thread, cx| {
1719 thread.set_profile(AgentProfileId("test".into()), cx);
1720 thread.add_tool(EchoTool);
1721 thread.add_tool(DelayTool);
1722 thread.add_tool(WordListTool);
1723 thread.add_tool(ToolRequiringPermission);
1724 thread.add_tool(InfiniteTool);
1725 });
1726
1727 // Set up multiple context servers with some overlapping tool names
1728 let _server1_calls = setup_context_server(
1729 "xxx",
1730 vec![
1731 context_server::types::Tool {
1732 name: "echo".into(), // Conflicts with native EchoTool
1733 description: None,
1734 input_schema: serde_json::to_value(EchoTool::input_schema(
1735 LanguageModelToolSchemaFormat::JsonSchema,
1736 ))
1737 .unwrap(),
1738 output_schema: None,
1739 annotations: None,
1740 },
1741 context_server::types::Tool {
1742 name: "unique_tool_1".into(),
1743 description: None,
1744 input_schema: json!({"type": "object", "properties": {}}),
1745 output_schema: None,
1746 annotations: None,
1747 },
1748 ],
1749 &context_server_store,
1750 cx,
1751 );
1752
1753 let _server2_calls = setup_context_server(
1754 "yyy",
1755 vec![
1756 context_server::types::Tool {
1757 name: "echo".into(), // Also conflicts with native EchoTool
1758 description: None,
1759 input_schema: serde_json::to_value(EchoTool::input_schema(
1760 LanguageModelToolSchemaFormat::JsonSchema,
1761 ))
1762 .unwrap(),
1763 output_schema: None,
1764 annotations: None,
1765 },
1766 context_server::types::Tool {
1767 name: "unique_tool_2".into(),
1768 description: None,
1769 input_schema: json!({"type": "object", "properties": {}}),
1770 output_schema: None,
1771 annotations: None,
1772 },
1773 context_server::types::Tool {
1774 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1775 description: None,
1776 input_schema: json!({"type": "object", "properties": {}}),
1777 output_schema: None,
1778 annotations: None,
1779 },
1780 context_server::types::Tool {
1781 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1782 description: None,
1783 input_schema: json!({"type": "object", "properties": {}}),
1784 output_schema: None,
1785 annotations: None,
1786 },
1787 ],
1788 &context_server_store,
1789 cx,
1790 );
1791 let _server3_calls = setup_context_server(
1792 "zzz",
1793 vec![
1794 context_server::types::Tool {
1795 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1796 description: None,
1797 input_schema: json!({"type": "object", "properties": {}}),
1798 output_schema: None,
1799 annotations: None,
1800 },
1801 context_server::types::Tool {
1802 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1803 description: None,
1804 input_schema: json!({"type": "object", "properties": {}}),
1805 output_schema: None,
1806 annotations: None,
1807 },
1808 context_server::types::Tool {
1809 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1810 description: None,
1811 input_schema: json!({"type": "object", "properties": {}}),
1812 output_schema: None,
1813 annotations: None,
1814 },
1815 ],
1816 &context_server_store,
1817 cx,
1818 );
1819
1820 // Server with spaces in name - tests snake_case conversion for API compatibility
1821 let _server4_calls = setup_context_server(
1822 "Azure DevOps",
1823 vec![context_server::types::Tool {
1824 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1825 description: None,
1826 input_schema: serde_json::to_value(EchoTool::input_schema(
1827 LanguageModelToolSchemaFormat::JsonSchema,
1828 ))
1829 .unwrap(),
1830 output_schema: None,
1831 annotations: None,
1832 }],
1833 &context_server_store,
1834 cx,
1835 );
1836
1837 thread
1838 .update(cx, |thread, cx| {
1839 thread.send(UserMessageId::new(), ["Go"], cx)
1840 })
1841 .unwrap();
1842 cx.run_until_parked();
1843 let completion = fake_model.pending_completions().pop().unwrap();
1844 assert_eq!(
1845 tool_names_for_completion(&completion),
1846 vec![
1847 "azure_dev_ops_echo",
1848 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1849 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1850 "delay",
1851 "echo",
1852 "infinite",
1853 "tool_requiring_permission",
1854 "unique_tool_1",
1855 "unique_tool_2",
1856 "word_list",
1857 "xxx_echo",
1858 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1859 "yyy_echo",
1860 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1861 ]
1862 );
1863}
1864
1865#[gpui::test]
1866#[cfg_attr(not(feature = "e2e"), ignore)]
1867async fn test_cancellation(cx: &mut TestAppContext) {
1868 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1869
1870 let mut events = thread
1871 .update(cx, |thread, cx| {
1872 thread.add_tool(InfiniteTool);
1873 thread.add_tool(EchoTool);
1874 thread.send(
1875 UserMessageId::new(),
1876 ["Call the echo tool, then call the infinite tool, then explain their output"],
1877 cx,
1878 )
1879 })
1880 .unwrap();
1881
1882 // Wait until both tools are called.
1883 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1884 let mut echo_id = None;
1885 let mut echo_completed = false;
1886 while let Some(event) = events.next().await {
1887 match event.unwrap() {
1888 ThreadEvent::ToolCall(tool_call) => {
1889 assert_eq!(tool_call.title, expected_tools.remove(0));
1890 if tool_call.title == "Echo" {
1891 echo_id = Some(tool_call.tool_call_id);
1892 }
1893 }
1894 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1895 acp::ToolCallUpdate {
1896 tool_call_id,
1897 fields:
1898 acp::ToolCallUpdateFields {
1899 status: Some(acp::ToolCallStatus::Completed),
1900 ..
1901 },
1902 ..
1903 },
1904 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1905 echo_completed = true;
1906 }
1907 _ => {}
1908 }
1909
1910 if expected_tools.is_empty() && echo_completed {
1911 break;
1912 }
1913 }
1914
1915 // Cancel the current send and ensure that the event stream is closed, even
1916 // if one of the tools is still running.
1917 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1918 let events = events.collect::<Vec<_>>().await;
1919 let last_event = events.last();
1920 assert!(
1921 matches!(
1922 last_event,
1923 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1924 ),
1925 "unexpected event {last_event:?}"
1926 );
1927
1928 // Ensure we can still send a new message after cancellation.
1929 let events = thread
1930 .update(cx, |thread, cx| {
1931 thread.send(
1932 UserMessageId::new(),
1933 ["Testing: reply with 'Hello' then stop."],
1934 cx,
1935 )
1936 })
1937 .unwrap()
1938 .collect::<Vec<_>>()
1939 .await;
1940 thread.update(cx, |thread, _cx| {
1941 let message = thread.last_received_or_pending_message().unwrap();
1942 let agent_message = message.as_agent_message().unwrap();
1943 assert_eq!(
1944 agent_message.content,
1945 vec![AgentMessageContent::Text("Hello".to_string())]
1946 );
1947 });
1948 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1949}
1950
1951#[gpui::test]
1952async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1953 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1954 always_allow_tools(cx);
1955 let fake_model = model.as_fake();
1956
1957 let environment = Rc::new(cx.update(|cx| {
1958 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1959 }));
1960 let handle = environment.terminal_handle.clone().unwrap();
1961
1962 let mut events = thread
1963 .update(cx, |thread, cx| {
1964 thread.add_tool(crate::TerminalTool::new(
1965 thread.project().clone(),
1966 environment,
1967 ));
1968 thread.send(UserMessageId::new(), ["run a command"], cx)
1969 })
1970 .unwrap();
1971
1972 cx.run_until_parked();
1973
1974 // Simulate the model calling the terminal tool
1975 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1976 LanguageModelToolUse {
1977 id: "terminal_tool_1".into(),
1978 name: TerminalTool::NAME.into(),
1979 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1980 input: json!({"command": "sleep 1000", "cd": "."}),
1981 is_input_complete: true,
1982 thought_signature: None,
1983 },
1984 ));
1985 fake_model.end_last_completion_stream();
1986
1987 // Wait for the terminal tool to start running
1988 wait_for_terminal_tool_started(&mut events, cx).await;
1989
1990 // Cancel the thread while the terminal is running
1991 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1992
1993 // Collect remaining events, driving the executor to let cancellation complete
1994 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1995
1996 // Verify the terminal was killed
1997 assert!(
1998 handle.was_killed(),
1999 "expected terminal handle to be killed on cancellation"
2000 );
2001
2002 // Verify we got a cancellation stop event
2003 assert_eq!(
2004 stop_events(remaining_events),
2005 vec![acp::StopReason::Cancelled],
2006 );
2007
2008 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
2009 thread.update(cx, |thread, _cx| {
2010 let message = thread.last_received_or_pending_message().unwrap();
2011 let agent_message = message.as_agent_message().unwrap();
2012
2013 let tool_use = agent_message
2014 .content
2015 .iter()
2016 .find_map(|content| match content {
2017 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2018 _ => None,
2019 })
2020 .expect("expected tool use in agent message");
2021
2022 let tool_result = agent_message
2023 .tool_results
2024 .get(&tool_use.id)
2025 .expect("expected tool result");
2026
2027 let result_text = match &tool_result.content {
2028 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2029 _ => panic!("expected text content in tool result"),
2030 };
2031
2032 // "partial output" comes from FakeTerminalHandle's output field
2033 assert!(
2034 result_text.contains("partial output"),
2035 "expected tool result to contain terminal output, got: {result_text}"
2036 );
2037 // Match the actual format from process_content in terminal_tool.rs
2038 assert!(
2039 result_text.contains("The user stopped this command"),
2040 "expected tool result to indicate user stopped, got: {result_text}"
2041 );
2042 });
2043
2044 // Verify we can send a new message after cancellation
2045 verify_thread_recovery(&thread, &fake_model, cx).await;
2046}
2047
2048#[gpui::test]
2049async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2050 // This test verifies that tools which properly handle cancellation via
2051 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2052 // to cancellation and report that they were cancelled.
2053 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2054 always_allow_tools(cx);
2055 let fake_model = model.as_fake();
2056
2057 let (tool, was_cancelled) = CancellationAwareTool::new();
2058
2059 let mut events = thread
2060 .update(cx, |thread, cx| {
2061 thread.add_tool(tool);
2062 thread.send(
2063 UserMessageId::new(),
2064 ["call the cancellation aware tool"],
2065 cx,
2066 )
2067 })
2068 .unwrap();
2069
2070 cx.run_until_parked();
2071
2072 // Simulate the model calling the cancellation-aware tool
2073 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2074 LanguageModelToolUse {
2075 id: "cancellation_aware_1".into(),
2076 name: "cancellation_aware".into(),
2077 raw_input: r#"{}"#.into(),
2078 input: json!({}),
2079 is_input_complete: true,
2080 thought_signature: None,
2081 },
2082 ));
2083 fake_model.end_last_completion_stream();
2084
2085 cx.run_until_parked();
2086
2087 // Wait for the tool call to be reported
2088 let mut tool_started = false;
2089 let deadline = cx.executor().num_cpus() * 100;
2090 for _ in 0..deadline {
2091 cx.run_until_parked();
2092
2093 while let Some(Some(event)) = events.next().now_or_never() {
2094 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2095 if tool_call.title == "Cancellation Aware Tool" {
2096 tool_started = true;
2097 break;
2098 }
2099 }
2100 }
2101
2102 if tool_started {
2103 break;
2104 }
2105
2106 cx.background_executor
2107 .timer(Duration::from_millis(10))
2108 .await;
2109 }
2110 assert!(tool_started, "expected cancellation aware tool to start");
2111
2112 // Cancel the thread and wait for it to complete
2113 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2114
2115 // The cancel task should complete promptly because the tool handles cancellation
2116 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2117 futures::select! {
2118 _ = cancel_task.fuse() => {}
2119 _ = timeout.fuse() => {
2120 panic!("cancel task timed out - tool did not respond to cancellation");
2121 }
2122 }
2123
2124 // Verify the tool detected cancellation via its flag
2125 assert!(
2126 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2127 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2128 );
2129
2130 // Collect remaining events
2131 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2132
2133 // Verify we got a cancellation stop event
2134 assert_eq!(
2135 stop_events(remaining_events),
2136 vec![acp::StopReason::Cancelled],
2137 );
2138
2139 // Verify we can send a new message after cancellation
2140 verify_thread_recovery(&thread, &fake_model, cx).await;
2141}
2142
2143/// Helper to verify thread can recover after cancellation by sending a simple message.
2144async fn verify_thread_recovery(
2145 thread: &Entity<Thread>,
2146 fake_model: &FakeLanguageModel,
2147 cx: &mut TestAppContext,
2148) {
2149 let events = thread
2150 .update(cx, |thread, cx| {
2151 thread.send(
2152 UserMessageId::new(),
2153 ["Testing: reply with 'Hello' then stop."],
2154 cx,
2155 )
2156 })
2157 .unwrap();
2158 cx.run_until_parked();
2159 fake_model.send_last_completion_stream_text_chunk("Hello");
2160 fake_model
2161 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2162 fake_model.end_last_completion_stream();
2163
2164 let events = events.collect::<Vec<_>>().await;
2165 thread.update(cx, |thread, _cx| {
2166 let message = thread.last_received_or_pending_message().unwrap();
2167 let agent_message = message.as_agent_message().unwrap();
2168 assert_eq!(
2169 agent_message.content,
2170 vec![AgentMessageContent::Text("Hello".to_string())]
2171 );
2172 });
2173 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2174}
2175
2176/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2177async fn wait_for_terminal_tool_started(
2178 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2179 cx: &mut TestAppContext,
2180) {
2181 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2182 for _ in 0..deadline {
2183 cx.run_until_parked();
2184
2185 while let Some(Some(event)) = events.next().now_or_never() {
2186 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2187 update,
2188 ))) = &event
2189 {
2190 if update.fields.content.as_ref().is_some_and(|content| {
2191 content
2192 .iter()
2193 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2194 }) {
2195 return;
2196 }
2197 }
2198 }
2199
2200 cx.background_executor
2201 .timer(Duration::from_millis(10))
2202 .await;
2203 }
2204 panic!("terminal tool did not start within the expected time");
2205}
2206
2207/// Collects events until a Stop event is received, driving the executor to completion.
2208async fn collect_events_until_stop(
2209 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2210 cx: &mut TestAppContext,
2211) -> Vec<Result<ThreadEvent>> {
2212 let mut collected = Vec::new();
2213 let deadline = cx.executor().num_cpus() * 200;
2214
2215 for _ in 0..deadline {
2216 cx.executor().advance_clock(Duration::from_millis(10));
2217 cx.run_until_parked();
2218
2219 while let Some(Some(event)) = events.next().now_or_never() {
2220 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2221 collected.push(event);
2222 if is_stop {
2223 return collected;
2224 }
2225 }
2226 }
2227 panic!(
2228 "did not receive Stop event within the expected time; collected {} events",
2229 collected.len()
2230 );
2231}
2232
2233#[gpui::test]
2234async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2235 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2236 always_allow_tools(cx);
2237 let fake_model = model.as_fake();
2238
2239 let environment = Rc::new(cx.update(|cx| {
2240 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2241 }));
2242 let handle = environment.terminal_handle.clone().unwrap();
2243
2244 let message_id = UserMessageId::new();
2245 let mut events = thread
2246 .update(cx, |thread, cx| {
2247 thread.add_tool(crate::TerminalTool::new(
2248 thread.project().clone(),
2249 environment,
2250 ));
2251 thread.send(message_id.clone(), ["run a command"], cx)
2252 })
2253 .unwrap();
2254
2255 cx.run_until_parked();
2256
2257 // Simulate the model calling the terminal tool
2258 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2259 LanguageModelToolUse {
2260 id: "terminal_tool_1".into(),
2261 name: TerminalTool::NAME.into(),
2262 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2263 input: json!({"command": "sleep 1000", "cd": "."}),
2264 is_input_complete: true,
2265 thought_signature: None,
2266 },
2267 ));
2268 fake_model.end_last_completion_stream();
2269
2270 // Wait for the terminal tool to start running
2271 wait_for_terminal_tool_started(&mut events, cx).await;
2272
2273 // Truncate the thread while the terminal is running
2274 thread
2275 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2276 .unwrap();
2277
2278 // Drive the executor to let cancellation complete
2279 let _ = collect_events_until_stop(&mut events, cx).await;
2280
2281 // Verify the terminal was killed
2282 assert!(
2283 handle.was_killed(),
2284 "expected terminal handle to be killed on truncate"
2285 );
2286
2287 // Verify the thread is empty after truncation
2288 thread.update(cx, |thread, _cx| {
2289 assert_eq!(
2290 thread.to_markdown(),
2291 "",
2292 "expected thread to be empty after truncating the only message"
2293 );
2294 });
2295
2296 // Verify we can send a new message after truncation
2297 verify_thread_recovery(&thread, &fake_model, cx).await;
2298}
2299
2300#[gpui::test]
2301async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2302 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2303 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2304 always_allow_tools(cx);
2305 let fake_model = model.as_fake();
2306
2307 let environment = Rc::new(MultiTerminalEnvironment::new());
2308
2309 let mut events = thread
2310 .update(cx, |thread, cx| {
2311 thread.add_tool(crate::TerminalTool::new(
2312 thread.project().clone(),
2313 environment.clone(),
2314 ));
2315 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2316 })
2317 .unwrap();
2318
2319 cx.run_until_parked();
2320
2321 // Simulate the model calling two terminal tools
2322 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2323 LanguageModelToolUse {
2324 id: "terminal_tool_1".into(),
2325 name: TerminalTool::NAME.into(),
2326 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2327 input: json!({"command": "sleep 1000", "cd": "."}),
2328 is_input_complete: true,
2329 thought_signature: None,
2330 },
2331 ));
2332 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2333 LanguageModelToolUse {
2334 id: "terminal_tool_2".into(),
2335 name: TerminalTool::NAME.into(),
2336 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2337 input: json!({"command": "sleep 2000", "cd": "."}),
2338 is_input_complete: true,
2339 thought_signature: None,
2340 },
2341 ));
2342 fake_model.end_last_completion_stream();
2343
2344 // Wait for both terminal tools to start by counting terminal content updates
2345 let mut terminals_started = 0;
2346 let deadline = cx.executor().num_cpus() * 100;
2347 for _ in 0..deadline {
2348 cx.run_until_parked();
2349
2350 while let Some(Some(event)) = events.next().now_or_never() {
2351 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2352 update,
2353 ))) = &event
2354 {
2355 if update.fields.content.as_ref().is_some_and(|content| {
2356 content
2357 .iter()
2358 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2359 }) {
2360 terminals_started += 1;
2361 if terminals_started >= 2 {
2362 break;
2363 }
2364 }
2365 }
2366 }
2367 if terminals_started >= 2 {
2368 break;
2369 }
2370
2371 cx.background_executor
2372 .timer(Duration::from_millis(10))
2373 .await;
2374 }
2375 assert!(
2376 terminals_started >= 2,
2377 "expected 2 terminal tools to start, got {terminals_started}"
2378 );
2379
2380 // Cancel the thread while both terminals are running
2381 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2382
2383 // Collect remaining events
2384 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2385
2386 // Verify both terminal handles were killed
2387 let handles = environment.handles();
2388 assert_eq!(
2389 handles.len(),
2390 2,
2391 "expected 2 terminal handles to be created"
2392 );
2393 assert!(
2394 handles[0].was_killed(),
2395 "expected first terminal handle to be killed on cancellation"
2396 );
2397 assert!(
2398 handles[1].was_killed(),
2399 "expected second terminal handle to be killed on cancellation"
2400 );
2401
2402 // Verify we got a cancellation stop event
2403 assert_eq!(
2404 stop_events(remaining_events),
2405 vec![acp::StopReason::Cancelled],
2406 );
2407}
2408
2409#[gpui::test]
2410async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2411 // Tests that clicking the stop button on the terminal card (as opposed to the main
2412 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2413 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2414 always_allow_tools(cx);
2415 let fake_model = model.as_fake();
2416
2417 let environment = Rc::new(cx.update(|cx| {
2418 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2419 }));
2420 let handle = environment.terminal_handle.clone().unwrap();
2421
2422 let mut events = thread
2423 .update(cx, |thread, cx| {
2424 thread.add_tool(crate::TerminalTool::new(
2425 thread.project().clone(),
2426 environment,
2427 ));
2428 thread.send(UserMessageId::new(), ["run a command"], cx)
2429 })
2430 .unwrap();
2431
2432 cx.run_until_parked();
2433
2434 // Simulate the model calling the terminal tool
2435 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2436 LanguageModelToolUse {
2437 id: "terminal_tool_1".into(),
2438 name: TerminalTool::NAME.into(),
2439 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2440 input: json!({"command": "sleep 1000", "cd": "."}),
2441 is_input_complete: true,
2442 thought_signature: None,
2443 },
2444 ));
2445 fake_model.end_last_completion_stream();
2446
2447 // Wait for the terminal tool to start running
2448 wait_for_terminal_tool_started(&mut events, cx).await;
2449
2450 // Simulate user clicking stop on the terminal card itself.
2451 // This sets the flag and signals exit (simulating what the real UI would do).
2452 handle.set_stopped_by_user(true);
2453 handle.killed.store(true, Ordering::SeqCst);
2454 handle.signal_exit();
2455
2456 // Wait for the tool to complete
2457 cx.run_until_parked();
2458
2459 // The thread continues after tool completion - simulate the model ending its turn
2460 fake_model
2461 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2462 fake_model.end_last_completion_stream();
2463
2464 // Collect remaining events
2465 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2466
2467 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2468 assert_eq!(
2469 stop_events(remaining_events),
2470 vec![acp::StopReason::EndTurn],
2471 );
2472
2473 // Verify the tool result indicates user stopped
2474 thread.update(cx, |thread, _cx| {
2475 let message = thread.last_received_or_pending_message().unwrap();
2476 let agent_message = message.as_agent_message().unwrap();
2477
2478 let tool_use = agent_message
2479 .content
2480 .iter()
2481 .find_map(|content| match content {
2482 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2483 _ => None,
2484 })
2485 .expect("expected tool use in agent message");
2486
2487 let tool_result = agent_message
2488 .tool_results
2489 .get(&tool_use.id)
2490 .expect("expected tool result");
2491
2492 let result_text = match &tool_result.content {
2493 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2494 _ => panic!("expected text content in tool result"),
2495 };
2496
2497 assert!(
2498 result_text.contains("The user stopped this command"),
2499 "expected tool result to indicate user stopped, got: {result_text}"
2500 );
2501 });
2502}
2503
2504#[gpui::test]
2505async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2506 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2507 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2508 always_allow_tools(cx);
2509 let fake_model = model.as_fake();
2510
2511 let environment = Rc::new(cx.update(|cx| {
2512 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2513 }));
2514 let handle = environment.terminal_handle.clone().unwrap();
2515
2516 let mut events = thread
2517 .update(cx, |thread, cx| {
2518 thread.add_tool(crate::TerminalTool::new(
2519 thread.project().clone(),
2520 environment,
2521 ));
2522 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2523 })
2524 .unwrap();
2525
2526 cx.run_until_parked();
2527
2528 // Simulate the model calling the terminal tool with a short timeout
2529 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2530 LanguageModelToolUse {
2531 id: "terminal_tool_1".into(),
2532 name: TerminalTool::NAME.into(),
2533 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2534 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2535 is_input_complete: true,
2536 thought_signature: None,
2537 },
2538 ));
2539 fake_model.end_last_completion_stream();
2540
2541 // Wait for the terminal tool to start running
2542 wait_for_terminal_tool_started(&mut events, cx).await;
2543
2544 // Advance clock past the timeout
2545 cx.executor().advance_clock(Duration::from_millis(200));
2546 cx.run_until_parked();
2547
2548 // The thread continues after tool completion - simulate the model ending its turn
2549 fake_model
2550 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2551 fake_model.end_last_completion_stream();
2552
2553 // Collect remaining events
2554 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2555
2556 // Verify the terminal was killed due to timeout
2557 assert!(
2558 handle.was_killed(),
2559 "expected terminal handle to be killed on timeout"
2560 );
2561
2562 // Verify we got an EndTurn (the tool completed, just with timeout)
2563 assert_eq!(
2564 stop_events(remaining_events),
2565 vec![acp::StopReason::EndTurn],
2566 );
2567
2568 // Verify the tool result indicates timeout, not user stopped
2569 thread.update(cx, |thread, _cx| {
2570 let message = thread.last_received_or_pending_message().unwrap();
2571 let agent_message = message.as_agent_message().unwrap();
2572
2573 let tool_use = agent_message
2574 .content
2575 .iter()
2576 .find_map(|content| match content {
2577 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2578 _ => None,
2579 })
2580 .expect("expected tool use in agent message");
2581
2582 let tool_result = agent_message
2583 .tool_results
2584 .get(&tool_use.id)
2585 .expect("expected tool result");
2586
2587 let result_text = match &tool_result.content {
2588 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2589 _ => panic!("expected text content in tool result"),
2590 };
2591
2592 assert!(
2593 result_text.contains("timed out"),
2594 "expected tool result to indicate timeout, got: {result_text}"
2595 );
2596 assert!(
2597 !result_text.contains("The user stopped"),
2598 "tool result should not mention user stopped when it timed out, got: {result_text}"
2599 );
2600 });
2601}
2602
2603#[gpui::test]
2604async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2605 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2606 let fake_model = model.as_fake();
2607
2608 let events_1 = thread
2609 .update(cx, |thread, cx| {
2610 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2611 })
2612 .unwrap();
2613 cx.run_until_parked();
2614 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2615 cx.run_until_parked();
2616
2617 let events_2 = thread
2618 .update(cx, |thread, cx| {
2619 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2620 })
2621 .unwrap();
2622 cx.run_until_parked();
2623 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2624 fake_model
2625 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2626 fake_model.end_last_completion_stream();
2627
2628 let events_1 = events_1.collect::<Vec<_>>().await;
2629 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2630 let events_2 = events_2.collect::<Vec<_>>().await;
2631 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2632}
2633
2634#[gpui::test]
2635async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) {
2636 // Regression test: when a completion fails with a retryable error (e.g. upstream 500),
2637 // the retry loop waits on a timer. If the user switches models and sends a new message
2638 // during that delay, the old turn should exit immediately instead of retrying with the
2639 // stale model.
2640 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2641 let model_a = model.as_fake();
2642
2643 // Start a turn with model_a.
2644 let events_1 = thread
2645 .update(cx, |thread, cx| {
2646 thread.send(UserMessageId::new(), ["Hello"], cx)
2647 })
2648 .unwrap();
2649 cx.run_until_parked();
2650 assert_eq!(model_a.completion_count(), 1);
2651
2652 // Model returns a retryable upstream 500. The turn enters the retry delay.
2653 model_a.send_last_completion_stream_error(
2654 LanguageModelCompletionError::UpstreamProviderError {
2655 message: "Internal server error".to_string(),
2656 status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
2657 retry_after: None,
2658 },
2659 );
2660 model_a.end_last_completion_stream();
2661 cx.run_until_parked();
2662
2663 // The old completion was consumed; model_a has no pending requests yet because the
2664 // retry timer hasn't fired.
2665 assert_eq!(model_a.completion_count(), 0);
2666
2667 // Switch to model_b and send a new message. This cancels the old turn.
2668 let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
2669 "fake", "model-b", "Model B", false,
2670 ));
2671 thread.update(cx, |thread, cx| {
2672 thread.set_model(model_b.clone(), cx);
2673 });
2674 let events_2 = thread
2675 .update(cx, |thread, cx| {
2676 thread.send(UserMessageId::new(), ["Continue"], cx)
2677 })
2678 .unwrap();
2679 cx.run_until_parked();
2680
2681 // model_b should have received its completion request.
2682 assert_eq!(model_b.as_fake().completion_count(), 1);
2683
2684 // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s).
2685 cx.executor().advance_clock(Duration::from_secs(10));
2686 cx.run_until_parked();
2687
2688 // model_a must NOT have received another completion request — the cancelled turn
2689 // should have exited during the retry delay rather than retrying with the old model.
2690 assert_eq!(
2691 model_a.completion_count(),
2692 0,
2693 "old model should not receive a retry request after cancellation"
2694 );
2695
2696 // Complete model_b's turn.
2697 model_b
2698 .as_fake()
2699 .send_last_completion_stream_text_chunk("Done!");
2700 model_b
2701 .as_fake()
2702 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2703 model_b.as_fake().end_last_completion_stream();
2704
2705 let events_1 = events_1.collect::<Vec<_>>().await;
2706 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2707
2708 let events_2 = events_2.collect::<Vec<_>>().await;
2709 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2710}
2711
2712#[gpui::test]
2713async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2714 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2715 let fake_model = model.as_fake();
2716
2717 let events_1 = thread
2718 .update(cx, |thread, cx| {
2719 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2720 })
2721 .unwrap();
2722 cx.run_until_parked();
2723 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2724 fake_model
2725 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2726 fake_model.end_last_completion_stream();
2727 let events_1 = events_1.collect::<Vec<_>>().await;
2728
2729 let events_2 = thread
2730 .update(cx, |thread, cx| {
2731 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2732 })
2733 .unwrap();
2734 cx.run_until_parked();
2735 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2736 fake_model
2737 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2738 fake_model.end_last_completion_stream();
2739 let events_2 = events_2.collect::<Vec<_>>().await;
2740
2741 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2742 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2743}
2744
2745#[gpui::test]
2746async fn test_refusal(cx: &mut TestAppContext) {
2747 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2748 let fake_model = model.as_fake();
2749
2750 let events = thread
2751 .update(cx, |thread, cx| {
2752 thread.send(UserMessageId::new(), ["Hello"], cx)
2753 })
2754 .unwrap();
2755 cx.run_until_parked();
2756 thread.read_with(cx, |thread, _| {
2757 assert_eq!(
2758 thread.to_markdown(),
2759 indoc! {"
2760 ## User
2761
2762 Hello
2763 "}
2764 );
2765 });
2766
2767 fake_model.send_last_completion_stream_text_chunk("Hey!");
2768 cx.run_until_parked();
2769 thread.read_with(cx, |thread, _| {
2770 assert_eq!(
2771 thread.to_markdown(),
2772 indoc! {"
2773 ## User
2774
2775 Hello
2776
2777 ## Assistant
2778
2779 Hey!
2780 "}
2781 );
2782 });
2783
2784 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2785 fake_model
2786 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2787 let events = events.collect::<Vec<_>>().await;
2788 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2789 thread.read_with(cx, |thread, _| {
2790 assert_eq!(thread.to_markdown(), "");
2791 });
2792}
2793
2794#[gpui::test]
2795async fn test_truncate_first_message(cx: &mut TestAppContext) {
2796 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2797 let fake_model = model.as_fake();
2798
2799 let message_id = UserMessageId::new();
2800 thread
2801 .update(cx, |thread, cx| {
2802 thread.send(message_id.clone(), ["Hello"], cx)
2803 })
2804 .unwrap();
2805 cx.run_until_parked();
2806 thread.read_with(cx, |thread, _| {
2807 assert_eq!(
2808 thread.to_markdown(),
2809 indoc! {"
2810 ## User
2811
2812 Hello
2813 "}
2814 );
2815 assert_eq!(thread.latest_token_usage(), None);
2816 });
2817
2818 fake_model.send_last_completion_stream_text_chunk("Hey!");
2819 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2820 language_model::TokenUsage {
2821 input_tokens: 32_000,
2822 output_tokens: 16_000,
2823 cache_creation_input_tokens: 0,
2824 cache_read_input_tokens: 0,
2825 },
2826 ));
2827 cx.run_until_parked();
2828 thread.read_with(cx, |thread, _| {
2829 assert_eq!(
2830 thread.to_markdown(),
2831 indoc! {"
2832 ## User
2833
2834 Hello
2835
2836 ## Assistant
2837
2838 Hey!
2839 "}
2840 );
2841 assert_eq!(
2842 thread.latest_token_usage(),
2843 Some(acp_thread::TokenUsage {
2844 used_tokens: 32_000 + 16_000,
2845 max_tokens: 1_000_000,
2846 max_output_tokens: None,
2847 input_tokens: 32_000,
2848 output_tokens: 16_000,
2849 })
2850 );
2851 });
2852
2853 thread
2854 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2855 .unwrap();
2856 cx.run_until_parked();
2857 thread.read_with(cx, |thread, _| {
2858 assert_eq!(thread.to_markdown(), "");
2859 assert_eq!(thread.latest_token_usage(), None);
2860 });
2861
2862 // Ensure we can still send a new message after truncation.
2863 thread
2864 .update(cx, |thread, cx| {
2865 thread.send(UserMessageId::new(), ["Hi"], cx)
2866 })
2867 .unwrap();
2868 thread.update(cx, |thread, _cx| {
2869 assert_eq!(
2870 thread.to_markdown(),
2871 indoc! {"
2872 ## User
2873
2874 Hi
2875 "}
2876 );
2877 });
2878 cx.run_until_parked();
2879 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2880 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2881 language_model::TokenUsage {
2882 input_tokens: 40_000,
2883 output_tokens: 20_000,
2884 cache_creation_input_tokens: 0,
2885 cache_read_input_tokens: 0,
2886 },
2887 ));
2888 cx.run_until_parked();
2889 thread.read_with(cx, |thread, _| {
2890 assert_eq!(
2891 thread.to_markdown(),
2892 indoc! {"
2893 ## User
2894
2895 Hi
2896
2897 ## Assistant
2898
2899 Ahoy!
2900 "}
2901 );
2902
2903 assert_eq!(
2904 thread.latest_token_usage(),
2905 Some(acp_thread::TokenUsage {
2906 used_tokens: 40_000 + 20_000,
2907 max_tokens: 1_000_000,
2908 max_output_tokens: None,
2909 input_tokens: 40_000,
2910 output_tokens: 20_000,
2911 })
2912 );
2913 });
2914}
2915
2916#[gpui::test]
2917async fn test_truncate_second_message(cx: &mut TestAppContext) {
2918 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2919 let fake_model = model.as_fake();
2920
2921 thread
2922 .update(cx, |thread, cx| {
2923 thread.send(UserMessageId::new(), ["Message 1"], cx)
2924 })
2925 .unwrap();
2926 cx.run_until_parked();
2927 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2928 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2929 language_model::TokenUsage {
2930 input_tokens: 32_000,
2931 output_tokens: 16_000,
2932 cache_creation_input_tokens: 0,
2933 cache_read_input_tokens: 0,
2934 },
2935 ));
2936 fake_model.end_last_completion_stream();
2937 cx.run_until_parked();
2938
2939 let assert_first_message_state = |cx: &mut TestAppContext| {
2940 thread.clone().read_with(cx, |thread, _| {
2941 assert_eq!(
2942 thread.to_markdown(),
2943 indoc! {"
2944 ## User
2945
2946 Message 1
2947
2948 ## Assistant
2949
2950 Message 1 response
2951 "}
2952 );
2953
2954 assert_eq!(
2955 thread.latest_token_usage(),
2956 Some(acp_thread::TokenUsage {
2957 used_tokens: 32_000 + 16_000,
2958 max_tokens: 1_000_000,
2959 max_output_tokens: None,
2960 input_tokens: 32_000,
2961 output_tokens: 16_000,
2962 })
2963 );
2964 });
2965 };
2966
2967 assert_first_message_state(cx);
2968
2969 let second_message_id = UserMessageId::new();
2970 thread
2971 .update(cx, |thread, cx| {
2972 thread.send(second_message_id.clone(), ["Message 2"], cx)
2973 })
2974 .unwrap();
2975 cx.run_until_parked();
2976
2977 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2978 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2979 language_model::TokenUsage {
2980 input_tokens: 40_000,
2981 output_tokens: 20_000,
2982 cache_creation_input_tokens: 0,
2983 cache_read_input_tokens: 0,
2984 },
2985 ));
2986 fake_model.end_last_completion_stream();
2987 cx.run_until_parked();
2988
2989 thread.read_with(cx, |thread, _| {
2990 assert_eq!(
2991 thread.to_markdown(),
2992 indoc! {"
2993 ## User
2994
2995 Message 1
2996
2997 ## Assistant
2998
2999 Message 1 response
3000
3001 ## User
3002
3003 Message 2
3004
3005 ## Assistant
3006
3007 Message 2 response
3008 "}
3009 );
3010
3011 assert_eq!(
3012 thread.latest_token_usage(),
3013 Some(acp_thread::TokenUsage {
3014 used_tokens: 40_000 + 20_000,
3015 max_tokens: 1_000_000,
3016 max_output_tokens: None,
3017 input_tokens: 40_000,
3018 output_tokens: 20_000,
3019 })
3020 );
3021 });
3022
3023 thread
3024 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
3025 .unwrap();
3026 cx.run_until_parked();
3027
3028 assert_first_message_state(cx);
3029}
3030
3031#[gpui::test]
3032async fn test_title_generation(cx: &mut TestAppContext) {
3033 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3034 let fake_model = model.as_fake();
3035
3036 let summary_model = Arc::new(FakeLanguageModel::default());
3037 thread.update(cx, |thread, cx| {
3038 thread.set_summarization_model(Some(summary_model.clone()), cx)
3039 });
3040
3041 let send = thread
3042 .update(cx, |thread, cx| {
3043 thread.send(UserMessageId::new(), ["Hello"], cx)
3044 })
3045 .unwrap();
3046 cx.run_until_parked();
3047
3048 fake_model.send_last_completion_stream_text_chunk("Hey!");
3049 fake_model.end_last_completion_stream();
3050 cx.run_until_parked();
3051 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
3052
3053 // Ensure the summary model has been invoked to generate a title.
3054 summary_model.send_last_completion_stream_text_chunk("Hello ");
3055 summary_model.send_last_completion_stream_text_chunk("world\nG");
3056 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
3057 summary_model.end_last_completion_stream();
3058 send.collect::<Vec<_>>().await;
3059 cx.run_until_parked();
3060 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
3061
3062 // Send another message, ensuring no title is generated this time.
3063 let send = thread
3064 .update(cx, |thread, cx| {
3065 thread.send(UserMessageId::new(), ["Hello again"], cx)
3066 })
3067 .unwrap();
3068 cx.run_until_parked();
3069 fake_model.send_last_completion_stream_text_chunk("Hey again!");
3070 fake_model.end_last_completion_stream();
3071 cx.run_until_parked();
3072 assert_eq!(summary_model.pending_completions(), Vec::new());
3073 send.collect::<Vec<_>>().await;
3074 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
3075}
3076
3077#[gpui::test]
3078async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
3079 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3080 let fake_model = model.as_fake();
3081
3082 let _events = thread
3083 .update(cx, |thread, cx| {
3084 thread.add_tool(ToolRequiringPermission);
3085 thread.add_tool(EchoTool);
3086 thread.send(UserMessageId::new(), ["Hey!"], cx)
3087 })
3088 .unwrap();
3089 cx.run_until_parked();
3090
3091 let permission_tool_use = LanguageModelToolUse {
3092 id: "tool_id_1".into(),
3093 name: ToolRequiringPermission::NAME.into(),
3094 raw_input: "{}".into(),
3095 input: json!({}),
3096 is_input_complete: true,
3097 thought_signature: None,
3098 };
3099 let echo_tool_use = LanguageModelToolUse {
3100 id: "tool_id_2".into(),
3101 name: EchoTool::NAME.into(),
3102 raw_input: json!({"text": "test"}).to_string(),
3103 input: json!({"text": "test"}),
3104 is_input_complete: true,
3105 thought_signature: None,
3106 };
3107 fake_model.send_last_completion_stream_text_chunk("Hi!");
3108 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3109 permission_tool_use,
3110 ));
3111 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3112 echo_tool_use.clone(),
3113 ));
3114 fake_model.end_last_completion_stream();
3115 cx.run_until_parked();
3116
3117 // Ensure pending tools are skipped when building a request.
3118 let request = thread
3119 .read_with(cx, |thread, cx| {
3120 thread.build_completion_request(CompletionIntent::EditFile, cx)
3121 })
3122 .unwrap();
3123 assert_eq!(
3124 request.messages[1..],
3125 vec![
3126 LanguageModelRequestMessage {
3127 role: Role::User,
3128 content: vec!["Hey!".into()],
3129 cache: true,
3130 reasoning_details: None,
3131 },
3132 LanguageModelRequestMessage {
3133 role: Role::Assistant,
3134 content: vec![
3135 MessageContent::Text("Hi!".into()),
3136 MessageContent::ToolUse(echo_tool_use.clone())
3137 ],
3138 cache: false,
3139 reasoning_details: None,
3140 },
3141 LanguageModelRequestMessage {
3142 role: Role::User,
3143 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3144 tool_use_id: echo_tool_use.id.clone(),
3145 tool_name: echo_tool_use.name,
3146 is_error: false,
3147 content: "test".into(),
3148 output: Some("test".into())
3149 })],
3150 cache: false,
3151 reasoning_details: None,
3152 },
3153 ],
3154 );
3155}
3156
3157#[gpui::test]
3158async fn test_agent_connection(cx: &mut TestAppContext) {
3159 cx.update(settings::init);
3160 let templates = Templates::new();
3161
3162 // Initialize language model system with test provider
3163 cx.update(|cx| {
3164 gpui_tokio::init(cx);
3165
3166 let http_client = FakeHttpClient::with_404_response();
3167 let clock = Arc::new(clock::FakeSystemClock::new());
3168 let client = Client::new(clock, http_client, cx);
3169 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3170 language_model::init(client.clone(), cx);
3171 language_models::init(user_store, client.clone(), cx);
3172 LanguageModelRegistry::test(cx);
3173 });
3174 cx.executor().forbid_parking();
3175
3176 // Create a project for new_thread
3177 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3178 fake_fs.insert_tree(path!("/test"), json!({})).await;
3179 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3180 let cwd = Path::new("/test");
3181 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3182
3183 // Create agent and connection
3184 let agent = NativeAgent::new(
3185 project.clone(),
3186 thread_store,
3187 templates.clone(),
3188 None,
3189 fake_fs.clone(),
3190 &mut cx.to_async(),
3191 )
3192 .await
3193 .unwrap();
3194 let connection = NativeAgentConnection(agent.clone());
3195
3196 // Create a thread using new_thread
3197 let connection_rc = Rc::new(connection.clone());
3198 let acp_thread = cx
3199 .update(|cx| connection_rc.new_session(project, cwd, cx))
3200 .await
3201 .expect("new_thread should succeed");
3202
3203 // Get the session_id from the AcpThread
3204 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3205
3206 // Test model_selector returns Some
3207 let selector_opt = connection.model_selector(&session_id);
3208 assert!(
3209 selector_opt.is_some(),
3210 "agent should always support ModelSelector"
3211 );
3212 let selector = selector_opt.unwrap();
3213
3214 // Test list_models
3215 let listed_models = cx
3216 .update(|cx| selector.list_models(cx))
3217 .await
3218 .expect("list_models should succeed");
3219 let AgentModelList::Grouped(listed_models) = listed_models else {
3220 panic!("Unexpected model list type");
3221 };
3222 assert!(!listed_models.is_empty(), "should have at least one model");
3223 assert_eq!(
3224 listed_models[&AgentModelGroupName("Fake".into())][0]
3225 .id
3226 .0
3227 .as_ref(),
3228 "fake/fake"
3229 );
3230
3231 // Test selected_model returns the default
3232 let model = cx
3233 .update(|cx| selector.selected_model(cx))
3234 .await
3235 .expect("selected_model should succeed");
3236 let model = cx
3237 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3238 .unwrap();
3239 let model = model.as_fake();
3240 assert_eq!(model.id().0, "fake", "should return default model");
3241
3242 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3243 cx.run_until_parked();
3244 model.send_last_completion_stream_text_chunk("def");
3245 cx.run_until_parked();
3246 acp_thread.read_with(cx, |thread, cx| {
3247 assert_eq!(
3248 thread.to_markdown(cx),
3249 indoc! {"
3250 ## User
3251
3252 abc
3253
3254 ## Assistant
3255
3256 def
3257
3258 "}
3259 )
3260 });
3261
3262 // Test cancel
3263 cx.update(|cx| connection.cancel(&session_id, cx));
3264 request.await.expect("prompt should fail gracefully");
3265
3266 // Explicitly close the session and drop the ACP thread.
3267 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3268 .await
3269 .unwrap();
3270 drop(acp_thread);
3271 let result = cx
3272 .update(|cx| {
3273 connection.prompt(
3274 Some(acp_thread::UserMessageId::new()),
3275 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3276 cx,
3277 )
3278 })
3279 .await;
3280 assert_eq!(
3281 result.as_ref().unwrap_err().to_string(),
3282 "Session not found",
3283 "unexpected result: {:?}",
3284 result
3285 );
3286}
3287
3288#[gpui::test]
3289async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3290 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3291 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
3292 let fake_model = model.as_fake();
3293
3294 let mut events = thread
3295 .update(cx, |thread, cx| {
3296 thread.send(UserMessageId::new(), ["Echo something"], cx)
3297 })
3298 .unwrap();
3299 cx.run_until_parked();
3300
3301 // Simulate streaming partial input.
3302 let input = json!({});
3303 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3304 LanguageModelToolUse {
3305 id: "1".into(),
3306 name: EchoTool::NAME.into(),
3307 raw_input: input.to_string(),
3308 input,
3309 is_input_complete: false,
3310 thought_signature: None,
3311 },
3312 ));
3313
3314 // Input streaming completed
3315 let input = json!({ "text": "Hello!" });
3316 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3317 LanguageModelToolUse {
3318 id: "1".into(),
3319 name: "echo".into(),
3320 raw_input: input.to_string(),
3321 input,
3322 is_input_complete: true,
3323 thought_signature: None,
3324 },
3325 ));
3326 fake_model.end_last_completion_stream();
3327 cx.run_until_parked();
3328
3329 let tool_call = expect_tool_call(&mut events).await;
3330 assert_eq!(
3331 tool_call,
3332 acp::ToolCall::new("1", "Echo")
3333 .raw_input(json!({}))
3334 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3335 );
3336 let update = expect_tool_call_update_fields(&mut events).await;
3337 assert_eq!(
3338 update,
3339 acp::ToolCallUpdate::new(
3340 "1",
3341 acp::ToolCallUpdateFields::new()
3342 .title("Echo")
3343 .kind(acp::ToolKind::Other)
3344 .raw_input(json!({ "text": "Hello!"}))
3345 )
3346 );
3347 let update = expect_tool_call_update_fields(&mut events).await;
3348 assert_eq!(
3349 update,
3350 acp::ToolCallUpdate::new(
3351 "1",
3352 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3353 )
3354 );
3355 let update = expect_tool_call_update_fields(&mut events).await;
3356 assert_eq!(
3357 update,
3358 acp::ToolCallUpdate::new(
3359 "1",
3360 acp::ToolCallUpdateFields::new()
3361 .status(acp::ToolCallStatus::Completed)
3362 .raw_output("Hello!")
3363 )
3364 );
3365}
3366
3367#[gpui::test]
3368async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3369 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3370 let fake_model = model.as_fake();
3371
3372 let mut events = thread
3373 .update(cx, |thread, cx| {
3374 thread.send(UserMessageId::new(), ["Hello!"], cx)
3375 })
3376 .unwrap();
3377 cx.run_until_parked();
3378
3379 fake_model.send_last_completion_stream_text_chunk("Hey!");
3380 fake_model.end_last_completion_stream();
3381
3382 let mut retry_events = Vec::new();
3383 while let Some(Ok(event)) = events.next().await {
3384 match event {
3385 ThreadEvent::Retry(retry_status) => {
3386 retry_events.push(retry_status);
3387 }
3388 ThreadEvent::Stop(..) => break,
3389 _ => {}
3390 }
3391 }
3392
3393 assert_eq!(retry_events.len(), 0);
3394 thread.read_with(cx, |thread, _cx| {
3395 assert_eq!(
3396 thread.to_markdown(),
3397 indoc! {"
3398 ## User
3399
3400 Hello!
3401
3402 ## Assistant
3403
3404 Hey!
3405 "}
3406 )
3407 });
3408}
3409
3410#[gpui::test]
3411async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3412 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3413 let fake_model = model.as_fake();
3414
3415 let mut events = thread
3416 .update(cx, |thread, cx| {
3417 thread.send(UserMessageId::new(), ["Hello!"], cx)
3418 })
3419 .unwrap();
3420 cx.run_until_parked();
3421
3422 fake_model.send_last_completion_stream_text_chunk("Hey,");
3423 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3424 provider: LanguageModelProviderName::new("Anthropic"),
3425 retry_after: Some(Duration::from_secs(3)),
3426 });
3427 fake_model.end_last_completion_stream();
3428
3429 cx.executor().advance_clock(Duration::from_secs(3));
3430 cx.run_until_parked();
3431
3432 fake_model.send_last_completion_stream_text_chunk("there!");
3433 fake_model.end_last_completion_stream();
3434 cx.run_until_parked();
3435
3436 let mut retry_events = Vec::new();
3437 while let Some(Ok(event)) = events.next().await {
3438 match event {
3439 ThreadEvent::Retry(retry_status) => {
3440 retry_events.push(retry_status);
3441 }
3442 ThreadEvent::Stop(..) => break,
3443 _ => {}
3444 }
3445 }
3446
3447 assert_eq!(retry_events.len(), 1);
3448 assert!(matches!(
3449 retry_events[0],
3450 acp_thread::RetryStatus { attempt: 1, .. }
3451 ));
3452 thread.read_with(cx, |thread, _cx| {
3453 assert_eq!(
3454 thread.to_markdown(),
3455 indoc! {"
3456 ## User
3457
3458 Hello!
3459
3460 ## Assistant
3461
3462 Hey,
3463
3464 [resume]
3465
3466 ## Assistant
3467
3468 there!
3469 "}
3470 )
3471 });
3472}
3473
3474#[gpui::test]
3475async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3476 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3477 let fake_model = model.as_fake();
3478
3479 let events = thread
3480 .update(cx, |thread, cx| {
3481 thread.add_tool(EchoTool);
3482 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3483 })
3484 .unwrap();
3485 cx.run_until_parked();
3486
3487 let tool_use_1 = LanguageModelToolUse {
3488 id: "tool_1".into(),
3489 name: EchoTool::NAME.into(),
3490 raw_input: json!({"text": "test"}).to_string(),
3491 input: json!({"text": "test"}),
3492 is_input_complete: true,
3493 thought_signature: None,
3494 };
3495 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3496 tool_use_1.clone(),
3497 ));
3498 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3499 provider: LanguageModelProviderName::new("Anthropic"),
3500 retry_after: Some(Duration::from_secs(3)),
3501 });
3502 fake_model.end_last_completion_stream();
3503
3504 cx.executor().advance_clock(Duration::from_secs(3));
3505 let completion = fake_model.pending_completions().pop().unwrap();
3506 assert_eq!(
3507 completion.messages[1..],
3508 vec![
3509 LanguageModelRequestMessage {
3510 role: Role::User,
3511 content: vec!["Call the echo tool!".into()],
3512 cache: false,
3513 reasoning_details: None,
3514 },
3515 LanguageModelRequestMessage {
3516 role: Role::Assistant,
3517 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3518 cache: false,
3519 reasoning_details: None,
3520 },
3521 LanguageModelRequestMessage {
3522 role: Role::User,
3523 content: vec![language_model::MessageContent::ToolResult(
3524 LanguageModelToolResult {
3525 tool_use_id: tool_use_1.id.clone(),
3526 tool_name: tool_use_1.name.clone(),
3527 is_error: false,
3528 content: "test".into(),
3529 output: Some("test".into())
3530 }
3531 )],
3532 cache: true,
3533 reasoning_details: None,
3534 },
3535 ]
3536 );
3537
3538 fake_model.send_last_completion_stream_text_chunk("Done");
3539 fake_model.end_last_completion_stream();
3540 cx.run_until_parked();
3541 events.collect::<Vec<_>>().await;
3542 thread.read_with(cx, |thread, _cx| {
3543 assert_eq!(
3544 thread.last_received_or_pending_message(),
3545 Some(Message::Agent(AgentMessage {
3546 content: vec![AgentMessageContent::Text("Done".into())],
3547 tool_results: IndexMap::default(),
3548 reasoning_details: None,
3549 }))
3550 );
3551 })
3552}
3553
3554#[gpui::test]
3555async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3556 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3557 let fake_model = model.as_fake();
3558
3559 let mut events = thread
3560 .update(cx, |thread, cx| {
3561 thread.send(UserMessageId::new(), ["Hello!"], cx)
3562 })
3563 .unwrap();
3564 cx.run_until_parked();
3565
3566 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3567 fake_model.send_last_completion_stream_error(
3568 LanguageModelCompletionError::ServerOverloaded {
3569 provider: LanguageModelProviderName::new("Anthropic"),
3570 retry_after: Some(Duration::from_secs(3)),
3571 },
3572 );
3573 fake_model.end_last_completion_stream();
3574 cx.executor().advance_clock(Duration::from_secs(3));
3575 cx.run_until_parked();
3576 }
3577
3578 let mut errors = Vec::new();
3579 let mut retry_events = Vec::new();
3580 while let Some(event) = events.next().await {
3581 match event {
3582 Ok(ThreadEvent::Retry(retry_status)) => {
3583 retry_events.push(retry_status);
3584 }
3585 Ok(ThreadEvent::Stop(..)) => break,
3586 Err(error) => errors.push(error),
3587 _ => {}
3588 }
3589 }
3590
3591 assert_eq!(
3592 retry_events.len(),
3593 crate::thread::MAX_RETRY_ATTEMPTS as usize
3594 );
3595 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3596 assert_eq!(retry_events[i].attempt, i + 1);
3597 }
3598 assert_eq!(errors.len(), 1);
3599 let error = errors[0]
3600 .downcast_ref::<LanguageModelCompletionError>()
3601 .unwrap();
3602 assert!(matches!(
3603 error,
3604 LanguageModelCompletionError::ServerOverloaded { .. }
3605 ));
3606}
3607
3608#[gpui::test]
3609async fn test_streaming_tool_completes_when_llm_stream_ends_without_final_input(
3610 cx: &mut TestAppContext,
3611) {
3612 init_test(cx);
3613 always_allow_tools(cx);
3614
3615 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3616 let fake_model = model.as_fake();
3617
3618 thread.update(cx, |thread, _cx| {
3619 thread.add_tool(StreamingEchoTool);
3620 });
3621
3622 let _events = thread
3623 .update(cx, |thread, cx| {
3624 thread.send(UserMessageId::new(), ["Use the streaming_echo tool"], cx)
3625 })
3626 .unwrap();
3627 cx.run_until_parked();
3628
3629 // Send a partial tool use (is_input_complete = false), simulating the LLM
3630 // streaming input for a tool.
3631 let tool_use = LanguageModelToolUse {
3632 id: "tool_1".into(),
3633 name: "streaming_echo".into(),
3634 raw_input: r#"{"text": "partial"}"#.into(),
3635 input: json!({"text": "partial"}),
3636 is_input_complete: false,
3637 thought_signature: None,
3638 };
3639 fake_model
3640 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
3641 cx.run_until_parked();
3642
3643 // Send a stream error WITHOUT ever sending is_input_complete = true.
3644 // Before the fix, this would deadlock: the tool waits for more partials
3645 // (or cancellation), run_turn_internal waits for the tool, and the sender
3646 // keeping the channel open lives inside RunningTurn.
3647 fake_model.send_last_completion_stream_error(
3648 LanguageModelCompletionError::UpstreamProviderError {
3649 message: "Internal server error".to_string(),
3650 status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
3651 retry_after: None,
3652 },
3653 );
3654 fake_model.end_last_completion_stream();
3655
3656 // Advance past the retry delay so run_turn_internal retries.
3657 cx.executor().advance_clock(Duration::from_secs(5));
3658 cx.run_until_parked();
3659
3660 // The retry request should contain the streaming tool's error result,
3661 // proving the tool terminated and its result was forwarded.
3662 let completion = fake_model
3663 .pending_completions()
3664 .pop()
3665 .expect("No running turn");
3666 assert_eq!(
3667 completion.messages[1..],
3668 vec![
3669 LanguageModelRequestMessage {
3670 role: Role::User,
3671 content: vec!["Use the streaming_echo tool".into()],
3672 cache: false,
3673 reasoning_details: None,
3674 },
3675 LanguageModelRequestMessage {
3676 role: Role::Assistant,
3677 content: vec![language_model::MessageContent::ToolUse(tool_use.clone())],
3678 cache: false,
3679 reasoning_details: None,
3680 },
3681 LanguageModelRequestMessage {
3682 role: Role::User,
3683 content: vec![language_model::MessageContent::ToolResult(
3684 LanguageModelToolResult {
3685 tool_use_id: tool_use.id.clone(),
3686 tool_name: tool_use.name,
3687 is_error: true,
3688 content: "Failed to receive tool input: tool input was not fully received"
3689 .into(),
3690 output: Some(
3691 "Failed to receive tool input: tool input was not fully received"
3692 .into()
3693 ),
3694 }
3695 )],
3696 cache: true,
3697 reasoning_details: None,
3698 },
3699 ]
3700 );
3701
3702 // Finish the retry round so the turn completes cleanly.
3703 fake_model.send_last_completion_stream_text_chunk("Done");
3704 fake_model.end_last_completion_stream();
3705 cx.run_until_parked();
3706
3707 thread.read_with(cx, |thread, _cx| {
3708 assert!(
3709 thread.is_turn_complete(),
3710 "Thread should not be stuck; the turn should have completed",
3711 );
3712 });
3713}
3714
3715/// Filters out the stop events for asserting against in tests
3716fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3717 result_events
3718 .into_iter()
3719 .filter_map(|event| match event.unwrap() {
3720 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3721 _ => None,
3722 })
3723 .collect()
3724}
3725
/// Fixture returned by `setup`: the thread under test together with the
/// collaborators it was constructed from.
struct ThreadTest {
    /// The language model handed to `Thread::new`.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context handed to `Thread::new`.
    project_context: Entity<ProjectContext>,
    /// The context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake file system backing the project and the settings file.
    fs: Arc<FakeFs>,
}
3733
3734enum TestModel {
3735 Sonnet4,
3736 Fake,
3737}
3738
3739impl TestModel {
3740 fn id(&self) -> LanguageModelId {
3741 match self {
3742 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3743 TestModel::Fake => unreachable!(),
3744 }
3745 }
3746}
3747
3748async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3749 cx.executor().allow_parking();
3750
3751 let fs = FakeFs::new(cx.background_executor.clone());
3752 fs.create_dir(paths::settings_file().parent().unwrap())
3753 .await
3754 .unwrap();
3755 fs.insert_file(
3756 paths::settings_file(),
3757 json!({
3758 "agent": {
3759 "default_profile": "test-profile",
3760 "profiles": {
3761 "test-profile": {
3762 "name": "Test Profile",
3763 "tools": {
3764 EchoTool::NAME: true,
3765 DelayTool::NAME: true,
3766 WordListTool::NAME: true,
3767 ToolRequiringPermission::NAME: true,
3768 InfiniteTool::NAME: true,
3769 CancellationAwareTool::NAME: true,
3770 StreamingEchoTool::NAME: true,
3771 (TerminalTool::NAME): true,
3772 }
3773 }
3774 }
3775 }
3776 })
3777 .to_string()
3778 .into_bytes(),
3779 )
3780 .await;
3781
3782 cx.update(|cx| {
3783 settings::init(cx);
3784
3785 match model {
3786 TestModel::Fake => {}
3787 TestModel::Sonnet4 => {
3788 gpui_tokio::init(cx);
3789 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3790 cx.set_http_client(Arc::new(http_client));
3791 let client = Client::production(cx);
3792 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3793 language_model::init(client.clone(), cx);
3794 language_models::init(user_store, client.clone(), cx);
3795 }
3796 };
3797
3798 watch_settings(fs.clone(), cx);
3799 });
3800
3801 let templates = Templates::new();
3802
3803 fs.insert_tree(path!("/test"), json!({})).await;
3804 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3805
3806 let model = cx
3807 .update(|cx| {
3808 if let TestModel::Fake = model {
3809 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3810 } else {
3811 let model_id = model.id();
3812 let models = LanguageModelRegistry::read_global(cx);
3813 let model = models
3814 .available_models(cx)
3815 .find(|model| model.id() == model_id)
3816 .unwrap();
3817
3818 let provider = models.provider(&model.provider_id()).unwrap();
3819 let authenticated = provider.authenticate(cx);
3820
3821 cx.spawn(async move |_cx| {
3822 authenticated.await.unwrap();
3823 model
3824 })
3825 }
3826 })
3827 .await;
3828
3829 let project_context = cx.new(|_cx| ProjectContext::default());
3830 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3831 let context_server_registry =
3832 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3833 let thread = cx.new(|cx| {
3834 Thread::new(
3835 project,
3836 project_context.clone(),
3837 context_server_registry,
3838 templates,
3839 Some(model.clone()),
3840 cx,
3841 )
3842 });
3843 ThreadTest {
3844 model,
3845 thread,
3846 project_context,
3847 context_server_store,
3848 fs,
3849 }
3850}
3851
/// Installs `env_logger` for the test binary, but only when `RUST_LOG` is set.
///
/// Registered through `ctor`, so it is not called from any test directly.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        // `try_init` rather than `init`: `init` panics when a global logger has
        // already been installed. The only error `try_init` reports is that a
        // logger is already set, which is harmless here, so it is ignored.
        let _ = env_logger::try_init();
    }
}
3859
3860fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3861 let fs = fs.clone();
3862 cx.spawn({
3863 async move |cx| {
3864 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3865 cx.background_executor(),
3866 fs,
3867 paths::settings_file().clone(),
3868 );
3869 let _watcher_task = watcher_task;
3870
3871 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3872 cx.update(|cx| {
3873 SettingsStore::update_global(cx, |settings, cx| {
3874 settings.set_user_settings(&new_settings_content, cx)
3875 })
3876 })
3877 .ok();
3878 }
3879 }
3880 })
3881 .detach();
3882}
3883
3884fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3885 completion
3886 .tools
3887 .iter()
3888 .map(|tool| tool.name.clone())
3889 .collect()
3890}
3891
3892fn setup_context_server(
3893 name: &'static str,
3894 tools: Vec<context_server::types::Tool>,
3895 context_server_store: &Entity<ContextServerStore>,
3896 cx: &mut TestAppContext,
3897) -> mpsc::UnboundedReceiver<(
3898 context_server::types::CallToolParams,
3899 oneshot::Sender<context_server::types::CallToolResponse>,
3900)> {
3901 cx.update(|cx| {
3902 let mut settings = ProjectSettings::get_global(cx).clone();
3903 settings.context_servers.insert(
3904 name.into(),
3905 project::project_settings::ContextServerSettings::Stdio {
3906 enabled: true,
3907 remote: false,
3908 command: ContextServerCommand {
3909 path: "somebinary".into(),
3910 args: Vec::new(),
3911 env: None,
3912 timeout: None,
3913 },
3914 },
3915 );
3916 ProjectSettings::override_global(settings, cx);
3917 });
3918
3919 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3920 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3921 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3922 context_server::types::InitializeResponse {
3923 protocol_version: context_server::types::ProtocolVersion(
3924 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3925 ),
3926 server_info: context_server::types::Implementation {
3927 name: name.into(),
3928 version: "1.0.0".to_string(),
3929 },
3930 capabilities: context_server::types::ServerCapabilities {
3931 tools: Some(context_server::types::ToolsCapabilities {
3932 list_changed: Some(true),
3933 }),
3934 ..Default::default()
3935 },
3936 meta: None,
3937 }
3938 })
3939 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3940 let tools = tools.clone();
3941 async move {
3942 context_server::types::ListToolsResponse {
3943 tools,
3944 next_cursor: None,
3945 meta: None,
3946 }
3947 }
3948 })
3949 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3950 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3951 async move {
3952 let (response_tx, response_rx) = oneshot::channel();
3953 mcp_tool_calls_tx
3954 .unbounded_send((params, response_tx))
3955 .unwrap();
3956 response_rx.await.unwrap()
3957 }
3958 });
3959 context_server_store.update(cx, |store, cx| {
3960 store.start_server(
3961 Arc::new(ContextServer::new(
3962 ContextServerId(name.into()),
3963 Arc::new(fake_transport),
3964 )),
3965 cx,
3966 );
3967 });
3968 cx.run_until_parked();
3969 mcp_tool_calls_rx
3970}
3971
3972#[gpui::test]
3973async fn test_tokens_before_message(cx: &mut TestAppContext) {
3974 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3975 let fake_model = model.as_fake();
3976
3977 // First message
3978 let message_1_id = UserMessageId::new();
3979 thread
3980 .update(cx, |thread, cx| {
3981 thread.send(message_1_id.clone(), ["First message"], cx)
3982 })
3983 .unwrap();
3984 cx.run_until_parked();
3985
3986 // Before any response, tokens_before_message should return None for first message
3987 thread.read_with(cx, |thread, _| {
3988 assert_eq!(
3989 thread.tokens_before_message(&message_1_id),
3990 None,
3991 "First message should have no tokens before it"
3992 );
3993 });
3994
3995 // Complete first message with usage
3996 fake_model.send_last_completion_stream_text_chunk("Response 1");
3997 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3998 language_model::TokenUsage {
3999 input_tokens: 100,
4000 output_tokens: 50,
4001 cache_creation_input_tokens: 0,
4002 cache_read_input_tokens: 0,
4003 },
4004 ));
4005 fake_model.end_last_completion_stream();
4006 cx.run_until_parked();
4007
4008 // First message still has no tokens before it
4009 thread.read_with(cx, |thread, _| {
4010 assert_eq!(
4011 thread.tokens_before_message(&message_1_id),
4012 None,
4013 "First message should still have no tokens before it after response"
4014 );
4015 });
4016
4017 // Second message
4018 let message_2_id = UserMessageId::new();
4019 thread
4020 .update(cx, |thread, cx| {
4021 thread.send(message_2_id.clone(), ["Second message"], cx)
4022 })
4023 .unwrap();
4024 cx.run_until_parked();
4025
4026 // Second message should have first message's input tokens before it
4027 thread.read_with(cx, |thread, _| {
4028 assert_eq!(
4029 thread.tokens_before_message(&message_2_id),
4030 Some(100),
4031 "Second message should have 100 tokens before it (from first request)"
4032 );
4033 });
4034
4035 // Complete second message
4036 fake_model.send_last_completion_stream_text_chunk("Response 2");
4037 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
4038 language_model::TokenUsage {
4039 input_tokens: 250, // Total for this request (includes previous context)
4040 output_tokens: 75,
4041 cache_creation_input_tokens: 0,
4042 cache_read_input_tokens: 0,
4043 },
4044 ));
4045 fake_model.end_last_completion_stream();
4046 cx.run_until_parked();
4047
4048 // Third message
4049 let message_3_id = UserMessageId::new();
4050 thread
4051 .update(cx, |thread, cx| {
4052 thread.send(message_3_id.clone(), ["Third message"], cx)
4053 })
4054 .unwrap();
4055 cx.run_until_parked();
4056
4057 // Third message should have second message's input tokens (250) before it
4058 thread.read_with(cx, |thread, _| {
4059 assert_eq!(
4060 thread.tokens_before_message(&message_3_id),
4061 Some(250),
4062 "Third message should have 250 tokens before it (from second request)"
4063 );
4064 // Second message should still have 100
4065 assert_eq!(
4066 thread.tokens_before_message(&message_2_id),
4067 Some(100),
4068 "Second message should still have 100 tokens before it"
4069 );
4070 // First message still has none
4071 assert_eq!(
4072 thread.tokens_before_message(&message_1_id),
4073 None,
4074 "First message should still have no tokens before it"
4075 );
4076 });
4077}
4078
4079#[gpui::test]
4080async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
4081 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4082 let fake_model = model.as_fake();
4083
4084 // Set up three messages with responses
4085 let message_1_id = UserMessageId::new();
4086 thread
4087 .update(cx, |thread, cx| {
4088 thread.send(message_1_id.clone(), ["Message 1"], cx)
4089 })
4090 .unwrap();
4091 cx.run_until_parked();
4092 fake_model.send_last_completion_stream_text_chunk("Response 1");
4093 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
4094 language_model::TokenUsage {
4095 input_tokens: 100,
4096 output_tokens: 50,
4097 cache_creation_input_tokens: 0,
4098 cache_read_input_tokens: 0,
4099 },
4100 ));
4101 fake_model.end_last_completion_stream();
4102 cx.run_until_parked();
4103
4104 let message_2_id = UserMessageId::new();
4105 thread
4106 .update(cx, |thread, cx| {
4107 thread.send(message_2_id.clone(), ["Message 2"], cx)
4108 })
4109 .unwrap();
4110 cx.run_until_parked();
4111 fake_model.send_last_completion_stream_text_chunk("Response 2");
4112 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
4113 language_model::TokenUsage {
4114 input_tokens: 250,
4115 output_tokens: 75,
4116 cache_creation_input_tokens: 0,
4117 cache_read_input_tokens: 0,
4118 },
4119 ));
4120 fake_model.end_last_completion_stream();
4121 cx.run_until_parked();
4122
4123 // Verify initial state
4124 thread.read_with(cx, |thread, _| {
4125 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
4126 });
4127
4128 // Truncate at message 2 (removes message 2 and everything after)
4129 thread
4130 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
4131 .unwrap();
4132 cx.run_until_parked();
4133
4134 // After truncation, message_2_id no longer exists, so lookup should return None
4135 thread.read_with(cx, |thread, _| {
4136 assert_eq!(
4137 thread.tokens_before_message(&message_2_id),
4138 None,
4139 "After truncation, message 2 no longer exists"
4140 );
4141 // Message 1 still exists but has no tokens before it
4142 assert_eq!(
4143 thread.tokens_before_message(&message_1_id),
4144 None,
4145 "First message still has no tokens before it"
4146 );
4147 });
4148}
4149
4150#[gpui::test]
4151async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
4152 init_test(cx);
4153
4154 let fs = FakeFs::new(cx.executor());
4155 fs.insert_tree("/root", json!({})).await;
4156 let project = Project::test(fs, ["/root".as_ref()], cx).await;
4157
4158 // Test 1: Deny rule blocks command
4159 {
4160 let environment = Rc::new(cx.update(|cx| {
4161 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4162 }));
4163
4164 cx.update(|cx| {
4165 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4166 settings.tool_permissions.tools.insert(
4167 TerminalTool::NAME.into(),
4168 agent_settings::ToolRules {
4169 default: Some(settings::ToolPermissionMode::Confirm),
4170 always_allow: vec![],
4171 always_deny: vec![
4172 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
4173 ],
4174 always_confirm: vec![],
4175 invalid_patterns: vec![],
4176 },
4177 );
4178 agent_settings::AgentSettings::override_global(settings, cx);
4179 });
4180
4181 #[allow(clippy::arc_with_non_send_sync)]
4182 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4183 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4184
4185 let task = cx.update(|cx| {
4186 tool.run(
4187 ToolInput::resolved(crate::TerminalToolInput {
4188 command: "rm -rf /".to_string(),
4189 cd: ".".to_string(),
4190 timeout_ms: None,
4191 }),
4192 event_stream,
4193 cx,
4194 )
4195 });
4196
4197 let result = task.await;
4198 assert!(
4199 result.is_err(),
4200 "expected command to be blocked by deny rule"
4201 );
4202 let err_msg = result.unwrap_err().to_lowercase();
4203 assert!(
4204 err_msg.contains("blocked"),
4205 "error should mention the command was blocked"
4206 );
4207 }
4208
4209 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4210 {
4211 let environment = Rc::new(cx.update(|cx| {
4212 FakeThreadEnvironment::default()
4213 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4214 }));
4215
4216 cx.update(|cx| {
4217 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4218 settings.tool_permissions.tools.insert(
4219 TerminalTool::NAME.into(),
4220 agent_settings::ToolRules {
4221 default: Some(settings::ToolPermissionMode::Deny),
4222 always_allow: vec![
4223 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4224 ],
4225 always_deny: vec![],
4226 always_confirm: vec![],
4227 invalid_patterns: vec![],
4228 },
4229 );
4230 agent_settings::AgentSettings::override_global(settings, cx);
4231 });
4232
4233 #[allow(clippy::arc_with_non_send_sync)]
4234 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4235 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4236
4237 let task = cx.update(|cx| {
4238 tool.run(
4239 ToolInput::resolved(crate::TerminalToolInput {
4240 command: "echo hello".to_string(),
4241 cd: ".".to_string(),
4242 timeout_ms: None,
4243 }),
4244 event_stream,
4245 cx,
4246 )
4247 });
4248
4249 let update = rx.expect_update_fields().await;
4250 assert!(
4251 update.content.iter().any(|blocks| {
4252 blocks
4253 .iter()
4254 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4255 }),
4256 "expected terminal content (allow rule should skip confirmation and override default deny)"
4257 );
4258
4259 let result = task.await;
4260 assert!(
4261 result.is_ok(),
4262 "expected command to succeed without confirmation"
4263 );
4264 }
4265
4266 // Test 3: global default: allow does NOT override always_confirm patterns
4267 {
4268 let environment = Rc::new(cx.update(|cx| {
4269 FakeThreadEnvironment::default()
4270 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4271 }));
4272
4273 cx.update(|cx| {
4274 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4275 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4276 settings.tool_permissions.tools.insert(
4277 TerminalTool::NAME.into(),
4278 agent_settings::ToolRules {
4279 default: Some(settings::ToolPermissionMode::Allow),
4280 always_allow: vec![],
4281 always_deny: vec![],
4282 always_confirm: vec![
4283 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4284 ],
4285 invalid_patterns: vec![],
4286 },
4287 );
4288 agent_settings::AgentSettings::override_global(settings, cx);
4289 });
4290
4291 #[allow(clippy::arc_with_non_send_sync)]
4292 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4293 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4294
4295 let _task = cx.update(|cx| {
4296 tool.run(
4297 ToolInput::resolved(crate::TerminalToolInput {
4298 command: "sudo rm file".to_string(),
4299 cd: ".".to_string(),
4300 timeout_ms: None,
4301 }),
4302 event_stream,
4303 cx,
4304 )
4305 });
4306
4307 // With global default: allow, confirm patterns are still respected
4308 // The expect_authorization() call will panic if no authorization is requested,
4309 // which validates that the confirm pattern still triggers confirmation
4310 let _auth = rx.expect_authorization().await;
4311
4312 drop(_task);
4313 }
4314
4315 // Test 4: tool-specific default: deny is respected even with global default: allow
4316 {
4317 let environment = Rc::new(cx.update(|cx| {
4318 FakeThreadEnvironment::default()
4319 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4320 }));
4321
4322 cx.update(|cx| {
4323 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4324 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4325 settings.tool_permissions.tools.insert(
4326 TerminalTool::NAME.into(),
4327 agent_settings::ToolRules {
4328 default: Some(settings::ToolPermissionMode::Deny),
4329 always_allow: vec![],
4330 always_deny: vec![],
4331 always_confirm: vec![],
4332 invalid_patterns: vec![],
4333 },
4334 );
4335 agent_settings::AgentSettings::override_global(settings, cx);
4336 });
4337
4338 #[allow(clippy::arc_with_non_send_sync)]
4339 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4340 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4341
4342 let task = cx.update(|cx| {
4343 tool.run(
4344 ToolInput::resolved(crate::TerminalToolInput {
4345 command: "echo hello".to_string(),
4346 cd: ".".to_string(),
4347 timeout_ms: None,
4348 }),
4349 event_stream,
4350 cx,
4351 )
4352 });
4353
4354 // tool-specific default: deny is respected even with global default: allow
4355 let result = task.await;
4356 assert!(
4357 result.is_err(),
4358 "expected command to be blocked by tool-specific deny default"
4359 );
4360 let err_msg = result.unwrap_err().to_lowercase();
4361 assert!(
4362 err_msg.contains("disabled"),
4363 "error should mention the tool is disabled, got: {err_msg}"
4364 );
4365 }
4366}
4367
4368#[gpui::test]
4369async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4370 init_test(cx);
4371 cx.update(|cx| {
4372 LanguageModelRegistry::test(cx);
4373 });
4374 cx.update(|cx| {
4375 cx.update_flags(true, vec!["subagents".to_string()]);
4376 });
4377
4378 let fs = FakeFs::new(cx.executor());
4379 fs.insert_tree(
4380 "/",
4381 json!({
4382 "a": {
4383 "b.md": "Lorem"
4384 }
4385 }),
4386 )
4387 .await;
4388 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4389 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4390 let agent = NativeAgent::new(
4391 project.clone(),
4392 thread_store.clone(),
4393 Templates::new(),
4394 None,
4395 fs.clone(),
4396 &mut cx.to_async(),
4397 )
4398 .await
4399 .unwrap();
4400 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4401
4402 let acp_thread = cx
4403 .update(|cx| {
4404 connection
4405 .clone()
4406 .new_session(project.clone(), Path::new(""), cx)
4407 })
4408 .await
4409 .unwrap();
4410 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4411 let thread = agent.read_with(cx, |agent, _| {
4412 agent.sessions.get(&session_id).unwrap().thread.clone()
4413 });
4414 let model = Arc::new(FakeLanguageModel::default());
4415
4416 // Ensure empty threads are not saved, even if they get mutated.
4417 thread.update(cx, |thread, cx| {
4418 thread.set_model(model.clone(), cx);
4419 });
4420 cx.run_until_parked();
4421
4422 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4423 cx.run_until_parked();
4424 model.send_last_completion_stream_text_chunk("spawning subagent");
4425 let subagent_tool_input = SpawnAgentToolInput {
4426 label: "label".to_string(),
4427 message: "subagent task prompt".to_string(),
4428 session_id: None,
4429 };
4430 let subagent_tool_use = LanguageModelToolUse {
4431 id: "subagent_1".into(),
4432 name: SpawnAgentTool::NAME.into(),
4433 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4434 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4435 is_input_complete: true,
4436 thought_signature: None,
4437 };
4438 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4439 subagent_tool_use,
4440 ));
4441 model.end_last_completion_stream();
4442
4443 cx.run_until_parked();
4444
4445 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4446 thread
4447 .running_subagent_ids(cx)
4448 .get(0)
4449 .expect("subagent thread should be running")
4450 .clone()
4451 });
4452
4453 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4454 agent
4455 .sessions
4456 .get(&subagent_session_id)
4457 .expect("subagent session should exist")
4458 .acp_thread
4459 .clone()
4460 });
4461
4462 model.send_last_completion_stream_text_chunk("subagent task response");
4463 model.end_last_completion_stream();
4464
4465 cx.run_until_parked();
4466
4467 assert_eq!(
4468 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4469 indoc! {"
4470 ## User
4471
4472 subagent task prompt
4473
4474 ## Assistant
4475
4476 subagent task response
4477
4478 "}
4479 );
4480
4481 model.send_last_completion_stream_text_chunk("Response");
4482 model.end_last_completion_stream();
4483
4484 send.await.unwrap();
4485
4486 assert_eq!(
4487 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4488 indoc! {r#"
4489 ## User
4490
4491 Prompt
4492
4493 ## Assistant
4494
4495 spawning subagent
4496
4497 **Tool Call: label**
4498 Status: Completed
4499
4500 subagent task response
4501
4502 ## Assistant
4503
4504 Response
4505
4506 "#},
4507 );
4508}
4509
4510#[gpui::test]
4511async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppContext) {
4512 init_test(cx);
4513 cx.update(|cx| {
4514 LanguageModelRegistry::test(cx);
4515 });
4516 cx.update(|cx| {
4517 cx.update_flags(true, vec!["subagents".to_string()]);
4518 });
4519
4520 let fs = FakeFs::new(cx.executor());
4521 fs.insert_tree(
4522 "/",
4523 json!({
4524 "a": {
4525 "b.md": "Lorem"
4526 }
4527 }),
4528 )
4529 .await;
4530 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4531 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4532 let agent = NativeAgent::new(
4533 project.clone(),
4534 thread_store.clone(),
4535 Templates::new(),
4536 None,
4537 fs.clone(),
4538 &mut cx.to_async(),
4539 )
4540 .await
4541 .unwrap();
4542 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4543
4544 let acp_thread = cx
4545 .update(|cx| {
4546 connection
4547 .clone()
4548 .new_session(project.clone(), Path::new(""), cx)
4549 })
4550 .await
4551 .unwrap();
4552 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4553 let thread = agent.read_with(cx, |agent, _| {
4554 agent.sessions.get(&session_id).unwrap().thread.clone()
4555 });
4556 let model = Arc::new(FakeLanguageModel::default());
4557
4558 // Ensure empty threads are not saved, even if they get mutated.
4559 thread.update(cx, |thread, cx| {
4560 thread.set_model(model.clone(), cx);
4561 });
4562 cx.run_until_parked();
4563
4564 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4565 cx.run_until_parked();
4566 model.send_last_completion_stream_text_chunk("spawning subagent");
4567 let subagent_tool_input = SpawnAgentToolInput {
4568 label: "label".to_string(),
4569 message: "subagent task prompt".to_string(),
4570 session_id: None,
4571 };
4572 let subagent_tool_use = LanguageModelToolUse {
4573 id: "subagent_1".into(),
4574 name: SpawnAgentTool::NAME.into(),
4575 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4576 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4577 is_input_complete: true,
4578 thought_signature: None,
4579 };
4580 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4581 subagent_tool_use,
4582 ));
4583 model.end_last_completion_stream();
4584
4585 cx.run_until_parked();
4586
4587 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4588 thread
4589 .running_subagent_ids(cx)
4590 .get(0)
4591 .expect("subagent thread should be running")
4592 .clone()
4593 });
4594
4595 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4596 agent
4597 .sessions
4598 .get(&subagent_session_id)
4599 .expect("subagent session should exist")
4600 .acp_thread
4601 .clone()
4602 });
4603
4604 model.send_last_completion_stream_text_chunk("subagent task response 1");
4605 model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
4606 text: "thinking more about the subagent task".into(),
4607 signature: None,
4608 });
4609 model.send_last_completion_stream_text_chunk("subagent task response 2");
4610 model.end_last_completion_stream();
4611
4612 cx.run_until_parked();
4613
4614 assert_eq!(
4615 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4616 indoc! {"
4617 ## User
4618
4619 subagent task prompt
4620
4621 ## Assistant
4622
4623 subagent task response 1
4624
4625 <thinking>
4626 thinking more about the subagent task
4627 </thinking>
4628
4629 subagent task response 2
4630
4631 "}
4632 );
4633
4634 model.send_last_completion_stream_text_chunk("Response");
4635 model.end_last_completion_stream();
4636
4637 send.await.unwrap();
4638
4639 assert_eq!(
4640 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4641 indoc! {r#"
4642 ## User
4643
4644 Prompt
4645
4646 ## Assistant
4647
4648 spawning subagent
4649
4650 **Tool Call: label**
4651 Status: Completed
4652
4653 subagent task response 1
4654
4655 subagent task response 2
4656
4657 ## Assistant
4658
4659 Response
4660
4661 "#},
4662 );
4663}
4664
4665#[gpui::test]
4666async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4667 init_test(cx);
4668 cx.update(|cx| {
4669 LanguageModelRegistry::test(cx);
4670 });
4671 cx.update(|cx| {
4672 cx.update_flags(true, vec!["subagents".to_string()]);
4673 });
4674
4675 let fs = FakeFs::new(cx.executor());
4676 fs.insert_tree(
4677 "/",
4678 json!({
4679 "a": {
4680 "b.md": "Lorem"
4681 }
4682 }),
4683 )
4684 .await;
4685 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4686 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4687 let agent = NativeAgent::new(
4688 project.clone(),
4689 thread_store.clone(),
4690 Templates::new(),
4691 None,
4692 fs.clone(),
4693 &mut cx.to_async(),
4694 )
4695 .await
4696 .unwrap();
4697 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4698
4699 let acp_thread = cx
4700 .update(|cx| {
4701 connection
4702 .clone()
4703 .new_session(project.clone(), Path::new(""), cx)
4704 })
4705 .await
4706 .unwrap();
4707 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4708 let thread = agent.read_with(cx, |agent, _| {
4709 agent.sessions.get(&session_id).unwrap().thread.clone()
4710 });
4711 let model = Arc::new(FakeLanguageModel::default());
4712
4713 // Ensure empty threads are not saved, even if they get mutated.
4714 thread.update(cx, |thread, cx| {
4715 thread.set_model(model.clone(), cx);
4716 });
4717 cx.run_until_parked();
4718
4719 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4720 cx.run_until_parked();
4721 model.send_last_completion_stream_text_chunk("spawning subagent");
4722 let subagent_tool_input = SpawnAgentToolInput {
4723 label: "label".to_string(),
4724 message: "subagent task prompt".to_string(),
4725 session_id: None,
4726 };
4727 let subagent_tool_use = LanguageModelToolUse {
4728 id: "subagent_1".into(),
4729 name: SpawnAgentTool::NAME.into(),
4730 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4731 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4732 is_input_complete: true,
4733 thought_signature: None,
4734 };
4735 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4736 subagent_tool_use,
4737 ));
4738 model.end_last_completion_stream();
4739
4740 cx.run_until_parked();
4741
4742 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4743 thread
4744 .running_subagent_ids(cx)
4745 .get(0)
4746 .expect("subagent thread should be running")
4747 .clone()
4748 });
4749 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4750 agent
4751 .sessions
4752 .get(&subagent_session_id)
4753 .expect("subagent session should exist")
4754 .acp_thread
4755 .clone()
4756 });
4757
4758 // model.send_last_completion_stream_text_chunk("subagent task response");
4759 // model.end_last_completion_stream();
4760
4761 // cx.run_until_parked();
4762
4763 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4764
4765 cx.run_until_parked();
4766
4767 send.await.unwrap();
4768
4769 acp_thread.read_with(cx, |thread, cx| {
4770 assert_eq!(thread.status(), ThreadStatus::Idle);
4771 assert_eq!(
4772 thread.to_markdown(cx),
4773 indoc! {"
4774 ## User
4775
4776 Prompt
4777
4778 ## Assistant
4779
4780 spawning subagent
4781
4782 **Tool Call: label**
4783 Status: Canceled
4784
4785 "}
4786 );
4787 });
4788 subagent_acp_thread.read_with(cx, |thread, cx| {
4789 assert_eq!(thread.status(), ThreadStatus::Idle);
4790 assert_eq!(
4791 thread.to_markdown(cx),
4792 indoc! {"
4793 ## User
4794
4795 subagent task prompt
4796
4797 "}
4798 );
4799 });
4800}
4801
4802#[gpui::test]
4803async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4804 init_test(cx);
4805 cx.update(|cx| {
4806 LanguageModelRegistry::test(cx);
4807 });
4808 cx.update(|cx| {
4809 cx.update_flags(true, vec!["subagents".to_string()]);
4810 });
4811
4812 let fs = FakeFs::new(cx.executor());
4813 fs.insert_tree(
4814 "/",
4815 json!({
4816 "a": {
4817 "b.md": "Lorem"
4818 }
4819 }),
4820 )
4821 .await;
4822 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4823 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4824 let agent = NativeAgent::new(
4825 project.clone(),
4826 thread_store.clone(),
4827 Templates::new(),
4828 None,
4829 fs.clone(),
4830 &mut cx.to_async(),
4831 )
4832 .await
4833 .unwrap();
4834 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4835
4836 let acp_thread = cx
4837 .update(|cx| {
4838 connection
4839 .clone()
4840 .new_session(project.clone(), Path::new(""), cx)
4841 })
4842 .await
4843 .unwrap();
4844 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4845 let thread = agent.read_with(cx, |agent, _| {
4846 agent.sessions.get(&session_id).unwrap().thread.clone()
4847 });
4848 let model = Arc::new(FakeLanguageModel::default());
4849
4850 thread.update(cx, |thread, cx| {
4851 thread.set_model(model.clone(), cx);
4852 });
4853 cx.run_until_parked();
4854
4855 // === First turn: create subagent ===
4856 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4857 cx.run_until_parked();
4858 model.send_last_completion_stream_text_chunk("spawning subagent");
4859 let subagent_tool_input = SpawnAgentToolInput {
4860 label: "initial task".to_string(),
4861 message: "do the first task".to_string(),
4862 session_id: None,
4863 };
4864 let subagent_tool_use = LanguageModelToolUse {
4865 id: "subagent_1".into(),
4866 name: SpawnAgentTool::NAME.into(),
4867 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4868 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4869 is_input_complete: true,
4870 thought_signature: None,
4871 };
4872 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4873 subagent_tool_use,
4874 ));
4875 model.end_last_completion_stream();
4876
4877 cx.run_until_parked();
4878
4879 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4880 thread
4881 .running_subagent_ids(cx)
4882 .get(0)
4883 .expect("subagent thread should be running")
4884 .clone()
4885 });
4886
4887 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4888 agent
4889 .sessions
4890 .get(&subagent_session_id)
4891 .expect("subagent session should exist")
4892 .acp_thread
4893 .clone()
4894 });
4895
4896 // Subagent responds
4897 model.send_last_completion_stream_text_chunk("first task response");
4898 model.end_last_completion_stream();
4899
4900 cx.run_until_parked();
4901
4902 // Parent model responds to complete first turn
4903 model.send_last_completion_stream_text_chunk("First response");
4904 model.end_last_completion_stream();
4905
4906 send.await.unwrap();
4907
4908 // Verify subagent is no longer running
4909 thread.read_with(cx, |thread, cx| {
4910 assert!(
4911 thread.running_subagent_ids(cx).is_empty(),
4912 "subagent should not be running after completion"
4913 );
4914 });
4915
4916 // === Second turn: resume subagent with session_id ===
4917 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4918 cx.run_until_parked();
4919 model.send_last_completion_stream_text_chunk("resuming subagent");
4920 let resume_tool_input = SpawnAgentToolInput {
4921 label: "follow-up task".to_string(),
4922 message: "do the follow-up task".to_string(),
4923 session_id: Some(subagent_session_id.clone()),
4924 };
4925 let resume_tool_use = LanguageModelToolUse {
4926 id: "subagent_2".into(),
4927 name: SpawnAgentTool::NAME.into(),
4928 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4929 input: serde_json::to_value(&resume_tool_input).unwrap(),
4930 is_input_complete: true,
4931 thought_signature: None,
4932 };
4933 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4934 model.end_last_completion_stream();
4935
4936 cx.run_until_parked();
4937
4938 // Subagent should be running again with the same session
4939 thread.read_with(cx, |thread, cx| {
4940 let running = thread.running_subagent_ids(cx);
4941 assert_eq!(running.len(), 1, "subagent should be running");
4942 assert_eq!(running[0], subagent_session_id, "should be same session");
4943 });
4944
4945 // Subagent responds to follow-up
4946 model.send_last_completion_stream_text_chunk("follow-up task response");
4947 model.end_last_completion_stream();
4948
4949 cx.run_until_parked();
4950
4951 // Parent model responds to complete second turn
4952 model.send_last_completion_stream_text_chunk("Second response");
4953 model.end_last_completion_stream();
4954
4955 send2.await.unwrap();
4956
4957 // Verify subagent is no longer running
4958 thread.read_with(cx, |thread, cx| {
4959 assert!(
4960 thread.running_subagent_ids(cx).is_empty(),
4961 "subagent should not be running after resume completion"
4962 );
4963 });
4964
4965 // Verify the subagent's acp thread has both conversation turns
4966 assert_eq!(
4967 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4968 indoc! {"
4969 ## User
4970
4971 do the first task
4972
4973 ## Assistant
4974
4975 first task response
4976
4977 ## User
4978
4979 do the follow-up task
4980
4981 ## Assistant
4982
4983 follow-up task response
4984
4985 "}
4986 );
4987}
4988
4989#[gpui::test]
4990async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4991 init_test(cx);
4992
4993 cx.update(|cx| {
4994 cx.update_flags(true, vec!["subagents".to_string()]);
4995 });
4996
4997 let fs = FakeFs::new(cx.executor());
4998 fs.insert_tree(path!("/test"), json!({})).await;
4999 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5000 let project_context = cx.new(|_cx| ProjectContext::default());
5001 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5002 let context_server_registry =
5003 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5004 let model = Arc::new(FakeLanguageModel::default());
5005
5006 let environment = Rc::new(cx.update(|cx| {
5007 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
5008 }));
5009
5010 let thread = cx.new(|cx| {
5011 let mut thread = Thread::new(
5012 project.clone(),
5013 project_context,
5014 context_server_registry,
5015 Templates::new(),
5016 Some(model),
5017 cx,
5018 );
5019 thread.add_default_tools(environment, cx);
5020 thread
5021 });
5022
5023 thread.read_with(cx, |thread, _| {
5024 assert!(
5025 thread.has_registered_tool(SpawnAgentTool::NAME),
5026 "subagent tool should be present when feature flag is enabled"
5027 );
5028 });
5029}
5030
5031#[gpui::test]
5032async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
5033 init_test(cx);
5034
5035 cx.update(|cx| {
5036 cx.update_flags(true, vec!["subagents".to_string()]);
5037 });
5038
5039 let fs = FakeFs::new(cx.executor());
5040 fs.insert_tree(path!("/test"), json!({})).await;
5041 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5042 let project_context = cx.new(|_cx| ProjectContext::default());
5043 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5044 let context_server_registry =
5045 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5046 let model = Arc::new(FakeLanguageModel::default());
5047
5048 let parent_thread = cx.new(|cx| {
5049 Thread::new(
5050 project.clone(),
5051 project_context,
5052 context_server_registry,
5053 Templates::new(),
5054 Some(model.clone()),
5055 cx,
5056 )
5057 });
5058
5059 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
5060 subagent_thread.read_with(cx, |subagent_thread, cx| {
5061 assert!(subagent_thread.is_subagent());
5062 assert_eq!(subagent_thread.depth(), 1);
5063 assert_eq!(
5064 subagent_thread.model().map(|model| model.id()),
5065 Some(model.id())
5066 );
5067 assert_eq!(
5068 subagent_thread.parent_thread_id(),
5069 Some(parent_thread.read(cx).id().clone())
5070 );
5071 });
5072}
5073
5074#[gpui::test]
5075async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
5076 init_test(cx);
5077
5078 cx.update(|cx| {
5079 cx.update_flags(true, vec!["subagents".to_string()]);
5080 });
5081
5082 let fs = FakeFs::new(cx.executor());
5083 fs.insert_tree(path!("/test"), json!({})).await;
5084 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5085 let project_context = cx.new(|_cx| ProjectContext::default());
5086 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5087 let context_server_registry =
5088 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5089 let model = Arc::new(FakeLanguageModel::default());
5090 let environment = Rc::new(cx.update(|cx| {
5091 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
5092 }));
5093
5094 let deep_parent_thread = cx.new(|cx| {
5095 let mut thread = Thread::new(
5096 project.clone(),
5097 project_context,
5098 context_server_registry,
5099 Templates::new(),
5100 Some(model.clone()),
5101 cx,
5102 );
5103 thread.set_subagent_context(SubagentContext {
5104 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
5105 depth: MAX_SUBAGENT_DEPTH - 1,
5106 });
5107 thread
5108 });
5109 let deep_subagent_thread = cx.new(|cx| {
5110 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
5111 thread.add_default_tools(environment, cx);
5112 thread
5113 });
5114
5115 deep_subagent_thread.read_with(cx, |thread, _| {
5116 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
5117 assert!(
5118 !thread.has_registered_tool(SpawnAgentTool::NAME),
5119 "subagent tool should not be present at max depth"
5120 );
5121 });
5122}
5123
5124#[gpui::test]
5125async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
5126 init_test(cx);
5127
5128 cx.update(|cx| {
5129 cx.update_flags(true, vec!["subagents".to_string()]);
5130 });
5131
5132 let fs = FakeFs::new(cx.executor());
5133 fs.insert_tree(path!("/test"), json!({})).await;
5134 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5135 let project_context = cx.new(|_cx| ProjectContext::default());
5136 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5137 let context_server_registry =
5138 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5139 let model = Arc::new(FakeLanguageModel::default());
5140
5141 let parent = cx.new(|cx| {
5142 Thread::new(
5143 project.clone(),
5144 project_context.clone(),
5145 context_server_registry.clone(),
5146 Templates::new(),
5147 Some(model.clone()),
5148 cx,
5149 )
5150 });
5151
5152 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
5153
5154 parent.update(cx, |thread, _cx| {
5155 thread.register_running_subagent(subagent.downgrade());
5156 });
5157
5158 subagent
5159 .update(cx, |thread, cx| {
5160 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
5161 })
5162 .unwrap();
5163 cx.run_until_parked();
5164
5165 subagent.read_with(cx, |thread, _| {
5166 assert!(!thread.is_turn_complete(), "subagent should be running");
5167 });
5168
5169 parent.update(cx, |thread, cx| {
5170 thread.cancel(cx).detach();
5171 });
5172
5173 subagent.read_with(cx, |thread, _| {
5174 assert!(
5175 thread.is_turn_complete(),
5176 "subagent should be cancelled when parent cancels"
5177 );
5178 });
5179}
5180
5181#[gpui::test]
5182async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
5183 init_test(cx);
5184 cx.update(|cx| {
5185 LanguageModelRegistry::test(cx);
5186 });
5187 cx.update(|cx| {
5188 cx.update_flags(true, vec!["subagents".to_string()]);
5189 });
5190
5191 let fs = FakeFs::new(cx.executor());
5192 fs.insert_tree(
5193 "/",
5194 json!({
5195 "a": {
5196 "b.md": "Lorem"
5197 }
5198 }),
5199 )
5200 .await;
5201 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5202 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5203 let agent = NativeAgent::new(
5204 project.clone(),
5205 thread_store.clone(),
5206 Templates::new(),
5207 None,
5208 fs.clone(),
5209 &mut cx.to_async(),
5210 )
5211 .await
5212 .unwrap();
5213 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5214
5215 let acp_thread = cx
5216 .update(|cx| {
5217 connection
5218 .clone()
5219 .new_session(project.clone(), Path::new(""), cx)
5220 })
5221 .await
5222 .unwrap();
5223 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5224 let thread = agent.read_with(cx, |agent, _| {
5225 agent.sessions.get(&session_id).unwrap().thread.clone()
5226 });
5227 let model = Arc::new(FakeLanguageModel::default());
5228
5229 thread.update(cx, |thread, cx| {
5230 thread.set_model(model.clone(), cx);
5231 });
5232 cx.run_until_parked();
5233
5234 // Start the parent turn
5235 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5236 cx.run_until_parked();
5237 model.send_last_completion_stream_text_chunk("spawning subagent");
5238 let subagent_tool_input = SpawnAgentToolInput {
5239 label: "label".to_string(),
5240 message: "subagent task prompt".to_string(),
5241 session_id: None,
5242 };
5243 let subagent_tool_use = LanguageModelToolUse {
5244 id: "subagent_1".into(),
5245 name: SpawnAgentTool::NAME.into(),
5246 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5247 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5248 is_input_complete: true,
5249 thought_signature: None,
5250 };
5251 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5252 subagent_tool_use,
5253 ));
5254 model.end_last_completion_stream();
5255
5256 cx.run_until_parked();
5257
5258 // Verify subagent is running
5259 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5260 thread
5261 .running_subagent_ids(cx)
5262 .get(0)
5263 .expect("subagent thread should be running")
5264 .clone()
5265 });
5266
5267 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
5268 model.send_last_completion_stream_text_chunk("partial work");
5269 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5270 TokenUsage {
5271 input_tokens: 850_000,
5272 output_tokens: 0,
5273 cache_creation_input_tokens: 0,
5274 cache_read_input_tokens: 0,
5275 },
5276 ));
5277
5278 cx.run_until_parked();
5279
5280 // The subagent should no longer be running
5281 thread.read_with(cx, |thread, cx| {
5282 assert!(
5283 thread.running_subagent_ids(cx).is_empty(),
5284 "subagent should be stopped after context window warning"
5285 );
5286 });
5287
5288 // The parent model should get a new completion request to respond to the tool error
5289 model.send_last_completion_stream_text_chunk("Response after warning");
5290 model.end_last_completion_stream();
5291
5292 send.await.unwrap();
5293
5294 // Verify the parent thread shows the warning error in the tool call
5295 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5296 assert!(
5297 markdown.contains("nearing the end of its context window"),
5298 "tool output should contain context window warning message, got:\n{markdown}"
5299 );
5300 assert!(
5301 markdown.contains("Status: Failed"),
5302 "tool call should have Failed status, got:\n{markdown}"
5303 );
5304
5305 // Verify the subagent session still exists (can be resumed)
5306 agent.read_with(cx, |agent, _cx| {
5307 assert!(
5308 agent.sessions.contains_key(&subagent_session_id),
5309 "subagent session should still exist for potential resume"
5310 );
5311 });
5312}
5313
5314#[gpui::test]
5315async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
5316 init_test(cx);
5317 cx.update(|cx| {
5318 LanguageModelRegistry::test(cx);
5319 });
5320 cx.update(|cx| {
5321 cx.update_flags(true, vec!["subagents".to_string()]);
5322 });
5323
5324 let fs = FakeFs::new(cx.executor());
5325 fs.insert_tree(
5326 "/",
5327 json!({
5328 "a": {
5329 "b.md": "Lorem"
5330 }
5331 }),
5332 )
5333 .await;
5334 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5335 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5336 let agent = NativeAgent::new(
5337 project.clone(),
5338 thread_store.clone(),
5339 Templates::new(),
5340 None,
5341 fs.clone(),
5342 &mut cx.to_async(),
5343 )
5344 .await
5345 .unwrap();
5346 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5347
5348 let acp_thread = cx
5349 .update(|cx| {
5350 connection
5351 .clone()
5352 .new_session(project.clone(), Path::new(""), cx)
5353 })
5354 .await
5355 .unwrap();
5356 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5357 let thread = agent.read_with(cx, |agent, _| {
5358 agent.sessions.get(&session_id).unwrap().thread.clone()
5359 });
5360 let model = Arc::new(FakeLanguageModel::default());
5361
5362 thread.update(cx, |thread, cx| {
5363 thread.set_model(model.clone(), cx);
5364 });
5365 cx.run_until_parked();
5366
5367 // === First turn: create subagent, trigger context window warning ===
5368 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
5369 cx.run_until_parked();
5370 model.send_last_completion_stream_text_chunk("spawning subagent");
5371 let subagent_tool_input = SpawnAgentToolInput {
5372 label: "initial task".to_string(),
5373 message: "do the first task".to_string(),
5374 session_id: None,
5375 };
5376 let subagent_tool_use = LanguageModelToolUse {
5377 id: "subagent_1".into(),
5378 name: SpawnAgentTool::NAME.into(),
5379 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5380 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5381 is_input_complete: true,
5382 thought_signature: None,
5383 };
5384 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5385 subagent_tool_use,
5386 ));
5387 model.end_last_completion_stream();
5388
5389 cx.run_until_parked();
5390
5391 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5392 thread
5393 .running_subagent_ids(cx)
5394 .get(0)
5395 .expect("subagent thread should be running")
5396 .clone()
5397 });
5398
5399 // Subagent sends a usage update that crosses the warning threshold.
5400 // This triggers Normal→Warning, stopping the subagent.
5401 model.send_last_completion_stream_text_chunk("partial work");
5402 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5403 TokenUsage {
5404 input_tokens: 850_000,
5405 output_tokens: 0,
5406 cache_creation_input_tokens: 0,
5407 cache_read_input_tokens: 0,
5408 },
5409 ));
5410
5411 cx.run_until_parked();
5412
5413 // Verify the first turn was stopped with a context window warning
5414 thread.read_with(cx, |thread, cx| {
5415 assert!(
5416 thread.running_subagent_ids(cx).is_empty(),
5417 "subagent should be stopped after context window warning"
5418 );
5419 });
5420
5421 // Parent model responds to complete first turn
5422 model.send_last_completion_stream_text_chunk("First response");
5423 model.end_last_completion_stream();
5424
5425 send.await.unwrap();
5426
5427 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5428 assert!(
5429 markdown.contains("nearing the end of its context window"),
5430 "first turn should have context window warning, got:\n{markdown}"
5431 );
5432
5433 // === Second turn: resume the same subagent (now at Warning level) ===
5434 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
5435 cx.run_until_parked();
5436 model.send_last_completion_stream_text_chunk("resuming subagent");
5437 let resume_tool_input = SpawnAgentToolInput {
5438 label: "follow-up task".to_string(),
5439 message: "do the follow-up task".to_string(),
5440 session_id: Some(subagent_session_id.clone()),
5441 };
5442 let resume_tool_use = LanguageModelToolUse {
5443 id: "subagent_2".into(),
5444 name: SpawnAgentTool::NAME.into(),
5445 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
5446 input: serde_json::to_value(&resume_tool_input).unwrap(),
5447 is_input_complete: true,
5448 thought_signature: None,
5449 };
5450 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
5451 model.end_last_completion_stream();
5452
5453 cx.run_until_parked();
5454
5455 // Subagent responds with tokens still at warning level (no worse).
5456 // Since ratio_before_prompt was already Warning, this should NOT
5457 // trigger the context window warning again.
5458 model.send_last_completion_stream_text_chunk("follow-up task response");
5459 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5460 TokenUsage {
5461 input_tokens: 870_000,
5462 output_tokens: 0,
5463 cache_creation_input_tokens: 0,
5464 cache_read_input_tokens: 0,
5465 },
5466 ));
5467 model.end_last_completion_stream();
5468
5469 cx.run_until_parked();
5470
5471 // Parent model responds to complete second turn
5472 model.send_last_completion_stream_text_chunk("Second response");
5473 model.end_last_completion_stream();
5474
5475 send2.await.unwrap();
5476
5477 // The resumed subagent should have completed normally since the ratio
5478 // didn't transition (it was Warning before and stayed at Warning)
5479 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5480 assert!(
5481 markdown.contains("follow-up task response"),
5482 "resumed subagent should complete normally when already at warning, got:\n{markdown}"
5483 );
5484 // The second tool call should NOT have a context window warning
5485 let second_tool_pos = markdown
5486 .find("follow-up task")
5487 .expect("should find follow-up tool call");
5488 let after_second_tool = &markdown[second_tool_pos..];
5489 assert!(
5490 !after_second_tool.contains("nearing the end of its context window"),
5491 "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
5492 );
5493}
5494
5495#[gpui::test]
5496async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
5497 init_test(cx);
5498 cx.update(|cx| {
5499 LanguageModelRegistry::test(cx);
5500 });
5501 cx.update(|cx| {
5502 cx.update_flags(true, vec!["subagents".to_string()]);
5503 });
5504
5505 let fs = FakeFs::new(cx.executor());
5506 fs.insert_tree(
5507 "/",
5508 json!({
5509 "a": {
5510 "b.md": "Lorem"
5511 }
5512 }),
5513 )
5514 .await;
5515 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5516 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5517 let agent = NativeAgent::new(
5518 project.clone(),
5519 thread_store.clone(),
5520 Templates::new(),
5521 None,
5522 fs.clone(),
5523 &mut cx.to_async(),
5524 )
5525 .await
5526 .unwrap();
5527 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5528
5529 let acp_thread = cx
5530 .update(|cx| {
5531 connection
5532 .clone()
5533 .new_session(project.clone(), Path::new(""), cx)
5534 })
5535 .await
5536 .unwrap();
5537 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5538 let thread = agent.read_with(cx, |agent, _| {
5539 agent.sessions.get(&session_id).unwrap().thread.clone()
5540 });
5541 let model = Arc::new(FakeLanguageModel::default());
5542
5543 thread.update(cx, |thread, cx| {
5544 thread.set_model(model.clone(), cx);
5545 });
5546 cx.run_until_parked();
5547
5548 // Start the parent turn
5549 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5550 cx.run_until_parked();
5551 model.send_last_completion_stream_text_chunk("spawning subagent");
5552 let subagent_tool_input = SpawnAgentToolInput {
5553 label: "label".to_string(),
5554 message: "subagent task prompt".to_string(),
5555 session_id: None,
5556 };
5557 let subagent_tool_use = LanguageModelToolUse {
5558 id: "subagent_1".into(),
5559 name: SpawnAgentTool::NAME.into(),
5560 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5561 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5562 is_input_complete: true,
5563 thought_signature: None,
5564 };
5565 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5566 subagent_tool_use,
5567 ));
5568 model.end_last_completion_stream();
5569
5570 cx.run_until_parked();
5571
5572 // Verify subagent is running
5573 thread.read_with(cx, |thread, cx| {
5574 assert!(
5575 !thread.running_subagent_ids(cx).is_empty(),
5576 "subagent should be running"
5577 );
5578 });
5579
5580 // The subagent's model returns a non-retryable error
5581 model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
5582 tokens: None,
5583 });
5584
5585 cx.run_until_parked();
5586
5587 // The subagent should no longer be running
5588 thread.read_with(cx, |thread, cx| {
5589 assert!(
5590 thread.running_subagent_ids(cx).is_empty(),
5591 "subagent should not be running after error"
5592 );
5593 });
5594
5595 // The parent model should get a new completion request to respond to the tool error
5596 model.send_last_completion_stream_text_chunk("Response after error");
5597 model.end_last_completion_stream();
5598
5599 send.await.unwrap();
5600
5601 // Verify the parent thread shows the error in the tool call
5602 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5603 assert!(
5604 markdown.contains("Status: Failed"),
5605 "tool call should have Failed status after model error, got:\n{markdown}"
5606 );
5607}
5608
5609#[gpui::test]
5610async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5611 init_test(cx);
5612
5613 let fs = FakeFs::new(cx.executor());
5614 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5615 .await;
5616 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5617
5618 cx.update(|cx| {
5619 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5620 settings.tool_permissions.tools.insert(
5621 EditFileTool::NAME.into(),
5622 agent_settings::ToolRules {
5623 default: Some(settings::ToolPermissionMode::Allow),
5624 always_allow: vec![],
5625 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5626 always_confirm: vec![],
5627 invalid_patterns: vec![],
5628 },
5629 );
5630 agent_settings::AgentSettings::override_global(settings, cx);
5631 });
5632
5633 let context_server_registry =
5634 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5635 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5636 let templates = crate::Templates::new();
5637 let thread = cx.new(|cx| {
5638 crate::Thread::new(
5639 project.clone(),
5640 cx.new(|_cx| prompt_store::ProjectContext::default()),
5641 context_server_registry,
5642 templates.clone(),
5643 None,
5644 cx,
5645 )
5646 });
5647
5648 #[allow(clippy::arc_with_non_send_sync)]
5649 let tool = Arc::new(crate::EditFileTool::new(
5650 project.clone(),
5651 thread.downgrade(),
5652 language_registry,
5653 templates,
5654 ));
5655 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5656
5657 let task = cx.update(|cx| {
5658 tool.run(
5659 ToolInput::resolved(crate::EditFileToolInput {
5660 display_description: "Edit sensitive file".to_string(),
5661 path: "root/sensitive_config.txt".into(),
5662 mode: crate::EditFileMode::Edit,
5663 }),
5664 event_stream,
5665 cx,
5666 )
5667 });
5668
5669 let result = task.await;
5670 assert!(result.is_err(), "expected edit to be blocked");
5671 assert!(
5672 result.unwrap_err().to_string().contains("blocked"),
5673 "error should mention the edit was blocked"
5674 );
5675}
5676
5677#[gpui::test]
5678async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5679 init_test(cx);
5680
5681 let fs = FakeFs::new(cx.executor());
5682 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5683 .await;
5684 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5685
5686 cx.update(|cx| {
5687 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5688 settings.tool_permissions.tools.insert(
5689 DeletePathTool::NAME.into(),
5690 agent_settings::ToolRules {
5691 default: Some(settings::ToolPermissionMode::Allow),
5692 always_allow: vec![],
5693 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5694 always_confirm: vec![],
5695 invalid_patterns: vec![],
5696 },
5697 );
5698 agent_settings::AgentSettings::override_global(settings, cx);
5699 });
5700
5701 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5702
5703 #[allow(clippy::arc_with_non_send_sync)]
5704 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5705 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5706
5707 let task = cx.update(|cx| {
5708 tool.run(
5709 ToolInput::resolved(crate::DeletePathToolInput {
5710 path: "root/important_data.txt".to_string(),
5711 }),
5712 event_stream,
5713 cx,
5714 )
5715 });
5716
5717 let result = task.await;
5718 assert!(result.is_err(), "expected deletion to be blocked");
5719 assert!(
5720 result.unwrap_err().contains("blocked"),
5721 "error should mention the deletion was blocked"
5722 );
5723}
5724
5725#[gpui::test]
5726async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5727 init_test(cx);
5728
5729 let fs = FakeFs::new(cx.executor());
5730 fs.insert_tree(
5731 "/root",
5732 json!({
5733 "safe.txt": "content",
5734 "protected": {}
5735 }),
5736 )
5737 .await;
5738 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5739
5740 cx.update(|cx| {
5741 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5742 settings.tool_permissions.tools.insert(
5743 MovePathTool::NAME.into(),
5744 agent_settings::ToolRules {
5745 default: Some(settings::ToolPermissionMode::Allow),
5746 always_allow: vec![],
5747 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5748 always_confirm: vec![],
5749 invalid_patterns: vec![],
5750 },
5751 );
5752 agent_settings::AgentSettings::override_global(settings, cx);
5753 });
5754
5755 #[allow(clippy::arc_with_non_send_sync)]
5756 let tool = Arc::new(crate::MovePathTool::new(project));
5757 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5758
5759 let task = cx.update(|cx| {
5760 tool.run(
5761 ToolInput::resolved(crate::MovePathToolInput {
5762 source_path: "root/safe.txt".to_string(),
5763 destination_path: "root/protected/safe.txt".to_string(),
5764 }),
5765 event_stream,
5766 cx,
5767 )
5768 });
5769
5770 let result = task.await;
5771 assert!(
5772 result.is_err(),
5773 "expected move to be blocked due to destination path"
5774 );
5775 assert!(
5776 result.unwrap_err().contains("blocked"),
5777 "error should mention the move was blocked"
5778 );
5779}
5780
5781#[gpui::test]
5782async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5783 init_test(cx);
5784
5785 let fs = FakeFs::new(cx.executor());
5786 fs.insert_tree(
5787 "/root",
5788 json!({
5789 "secret.txt": "secret content",
5790 "public": {}
5791 }),
5792 )
5793 .await;
5794 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5795
5796 cx.update(|cx| {
5797 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5798 settings.tool_permissions.tools.insert(
5799 MovePathTool::NAME.into(),
5800 agent_settings::ToolRules {
5801 default: Some(settings::ToolPermissionMode::Allow),
5802 always_allow: vec![],
5803 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5804 always_confirm: vec![],
5805 invalid_patterns: vec![],
5806 },
5807 );
5808 agent_settings::AgentSettings::override_global(settings, cx);
5809 });
5810
5811 #[allow(clippy::arc_with_non_send_sync)]
5812 let tool = Arc::new(crate::MovePathTool::new(project));
5813 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5814
5815 let task = cx.update(|cx| {
5816 tool.run(
5817 ToolInput::resolved(crate::MovePathToolInput {
5818 source_path: "root/secret.txt".to_string(),
5819 destination_path: "root/public/not_secret.txt".to_string(),
5820 }),
5821 event_stream,
5822 cx,
5823 )
5824 });
5825
5826 let result = task.await;
5827 assert!(
5828 result.is_err(),
5829 "expected move to be blocked due to source path"
5830 );
5831 assert!(
5832 result.unwrap_err().contains("blocked"),
5833 "error should mention the move was blocked"
5834 );
5835}
5836
5837#[gpui::test]
5838async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5839 init_test(cx);
5840
5841 let fs = FakeFs::new(cx.executor());
5842 fs.insert_tree(
5843 "/root",
5844 json!({
5845 "confidential.txt": "confidential data",
5846 "dest": {}
5847 }),
5848 )
5849 .await;
5850 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5851
5852 cx.update(|cx| {
5853 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5854 settings.tool_permissions.tools.insert(
5855 CopyPathTool::NAME.into(),
5856 agent_settings::ToolRules {
5857 default: Some(settings::ToolPermissionMode::Allow),
5858 always_allow: vec![],
5859 always_deny: vec![
5860 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5861 ],
5862 always_confirm: vec![],
5863 invalid_patterns: vec![],
5864 },
5865 );
5866 agent_settings::AgentSettings::override_global(settings, cx);
5867 });
5868
5869 #[allow(clippy::arc_with_non_send_sync)]
5870 let tool = Arc::new(crate::CopyPathTool::new(project));
5871 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5872
5873 let task = cx.update(|cx| {
5874 tool.run(
5875 ToolInput::resolved(crate::CopyPathToolInput {
5876 source_path: "root/confidential.txt".to_string(),
5877 destination_path: "root/dest/copy.txt".to_string(),
5878 }),
5879 event_stream,
5880 cx,
5881 )
5882 });
5883
5884 let result = task.await;
5885 assert!(result.is_err(), "expected copy to be blocked");
5886 assert!(
5887 result.unwrap_err().contains("blocked"),
5888 "error should mention the copy was blocked"
5889 );
5890}
5891
5892#[gpui::test]
5893async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5894 init_test(cx);
5895
5896 let fs = FakeFs::new(cx.executor());
5897 fs.insert_tree(
5898 "/root",
5899 json!({
5900 "normal.txt": "normal content",
5901 "readonly": {
5902 "config.txt": "readonly content"
5903 }
5904 }),
5905 )
5906 .await;
5907 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5908
5909 cx.update(|cx| {
5910 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5911 settings.tool_permissions.tools.insert(
5912 SaveFileTool::NAME.into(),
5913 agent_settings::ToolRules {
5914 default: Some(settings::ToolPermissionMode::Allow),
5915 always_allow: vec![],
5916 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5917 always_confirm: vec![],
5918 invalid_patterns: vec![],
5919 },
5920 );
5921 agent_settings::AgentSettings::override_global(settings, cx);
5922 });
5923
5924 #[allow(clippy::arc_with_non_send_sync)]
5925 let tool = Arc::new(crate::SaveFileTool::new(project));
5926 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5927
5928 let task = cx.update(|cx| {
5929 tool.run(
5930 ToolInput::resolved(crate::SaveFileToolInput {
5931 paths: vec![
5932 std::path::PathBuf::from("root/normal.txt"),
5933 std::path::PathBuf::from("root/readonly/config.txt"),
5934 ],
5935 }),
5936 event_stream,
5937 cx,
5938 )
5939 });
5940
5941 let result = task.await;
5942 assert!(
5943 result.is_err(),
5944 "expected save to be blocked due to denied path"
5945 );
5946 assert!(
5947 result.unwrap_err().contains("blocked"),
5948 "error should mention the save was blocked"
5949 );
5950}
5951
5952#[gpui::test]
5953async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5954 init_test(cx);
5955
5956 let fs = FakeFs::new(cx.executor());
5957 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5958 .await;
5959 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5960
5961 cx.update(|cx| {
5962 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5963 settings.tool_permissions.tools.insert(
5964 SaveFileTool::NAME.into(),
5965 agent_settings::ToolRules {
5966 default: Some(settings::ToolPermissionMode::Allow),
5967 always_allow: vec![],
5968 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5969 always_confirm: vec![],
5970 invalid_patterns: vec![],
5971 },
5972 );
5973 agent_settings::AgentSettings::override_global(settings, cx);
5974 });
5975
5976 #[allow(clippy::arc_with_non_send_sync)]
5977 let tool = Arc::new(crate::SaveFileTool::new(project));
5978 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5979
5980 let task = cx.update(|cx| {
5981 tool.run(
5982 ToolInput::resolved(crate::SaveFileToolInput {
5983 paths: vec![std::path::PathBuf::from("root/config.secret")],
5984 }),
5985 event_stream,
5986 cx,
5987 )
5988 });
5989
5990 let result = task.await;
5991 assert!(result.is_err(), "expected save to be blocked");
5992 assert!(
5993 result.unwrap_err().contains("blocked"),
5994 "error should mention the save was blocked"
5995 );
5996}
5997
5998#[gpui::test]
5999async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
6000 init_test(cx);
6001
6002 cx.update(|cx| {
6003 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6004 settings.tool_permissions.tools.insert(
6005 WebSearchTool::NAME.into(),
6006 agent_settings::ToolRules {
6007 default: Some(settings::ToolPermissionMode::Allow),
6008 always_allow: vec![],
6009 always_deny: vec![
6010 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
6011 ],
6012 always_confirm: vec![],
6013 invalid_patterns: vec![],
6014 },
6015 );
6016 agent_settings::AgentSettings::override_global(settings, cx);
6017 });
6018
6019 #[allow(clippy::arc_with_non_send_sync)]
6020 let tool = Arc::new(crate::WebSearchTool);
6021 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6022
6023 let input: crate::WebSearchToolInput =
6024 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
6025
6026 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6027
6028 let result = task.await;
6029 assert!(result.is_err(), "expected search to be blocked");
6030 match result.unwrap_err() {
6031 crate::WebSearchToolOutput::Error { error } => {
6032 assert!(
6033 error.contains("blocked"),
6034 "error should mention the search was blocked"
6035 );
6036 }
6037 other => panic!("expected Error variant, got: {other:?}"),
6038 }
6039}
6040
6041#[gpui::test]
6042async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6043 init_test(cx);
6044
6045 let fs = FakeFs::new(cx.executor());
6046 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
6047 .await;
6048 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
6049
6050 cx.update(|cx| {
6051 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6052 settings.tool_permissions.tools.insert(
6053 EditFileTool::NAME.into(),
6054 agent_settings::ToolRules {
6055 default: Some(settings::ToolPermissionMode::Confirm),
6056 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
6057 always_deny: vec![],
6058 always_confirm: vec![],
6059 invalid_patterns: vec![],
6060 },
6061 );
6062 agent_settings::AgentSettings::override_global(settings, cx);
6063 });
6064
6065 let context_server_registry =
6066 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6067 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6068 let templates = crate::Templates::new();
6069 let thread = cx.new(|cx| {
6070 crate::Thread::new(
6071 project.clone(),
6072 cx.new(|_cx| prompt_store::ProjectContext::default()),
6073 context_server_registry,
6074 templates.clone(),
6075 None,
6076 cx,
6077 )
6078 });
6079
6080 #[allow(clippy::arc_with_non_send_sync)]
6081 let tool = Arc::new(crate::EditFileTool::new(
6082 project,
6083 thread.downgrade(),
6084 language_registry,
6085 templates,
6086 ));
6087 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6088
6089 let _task = cx.update(|cx| {
6090 tool.run(
6091 ToolInput::resolved(crate::EditFileToolInput {
6092 display_description: "Edit README".to_string(),
6093 path: "root/README.md".into(),
6094 mode: crate::EditFileMode::Edit,
6095 }),
6096 event_stream,
6097 cx,
6098 )
6099 });
6100
6101 cx.run_until_parked();
6102
6103 let event = rx.try_next();
6104 assert!(
6105 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6106 "expected no authorization request for allowed .md file"
6107 );
6108}
6109
6110#[gpui::test]
6111async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
6112 init_test(cx);
6113
6114 let fs = FakeFs::new(cx.executor());
6115 fs.insert_tree(
6116 "/root",
6117 json!({
6118 ".zed": {
6119 "settings.json": "{}"
6120 },
6121 "README.md": "# Hello"
6122 }),
6123 )
6124 .await;
6125 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
6126
6127 cx.update(|cx| {
6128 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6129 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
6130 agent_settings::AgentSettings::override_global(settings, cx);
6131 });
6132
6133 let context_server_registry =
6134 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6135 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6136 let templates = crate::Templates::new();
6137 let thread = cx.new(|cx| {
6138 crate::Thread::new(
6139 project.clone(),
6140 cx.new(|_cx| prompt_store::ProjectContext::default()),
6141 context_server_registry,
6142 templates.clone(),
6143 None,
6144 cx,
6145 )
6146 });
6147
6148 #[allow(clippy::arc_with_non_send_sync)]
6149 let tool = Arc::new(crate::EditFileTool::new(
6150 project,
6151 thread.downgrade(),
6152 language_registry,
6153 templates,
6154 ));
6155
6156 // Editing a file inside .zed/ should still prompt even with global default: allow,
6157 // because local settings paths are sensitive and require confirmation regardless.
6158 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6159 let _task = cx.update(|cx| {
6160 tool.run(
6161 ToolInput::resolved(crate::EditFileToolInput {
6162 display_description: "Edit local settings".to_string(),
6163 path: "root/.zed/settings.json".into(),
6164 mode: crate::EditFileMode::Edit,
6165 }),
6166 event_stream,
6167 cx,
6168 )
6169 });
6170
6171 let _update = rx.expect_update_fields().await;
6172 let _auth = rx.expect_authorization().await;
6173}
6174
6175#[gpui::test]
6176async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
6177 init_test(cx);
6178
6179 cx.update(|cx| {
6180 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6181 settings.tool_permissions.tools.insert(
6182 FetchTool::NAME.into(),
6183 agent_settings::ToolRules {
6184 default: Some(settings::ToolPermissionMode::Allow),
6185 always_allow: vec![],
6186 always_deny: vec![
6187 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
6188 ],
6189 always_confirm: vec![],
6190 invalid_patterns: vec![],
6191 },
6192 );
6193 agent_settings::AgentSettings::override_global(settings, cx);
6194 });
6195
6196 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6197
6198 #[allow(clippy::arc_with_non_send_sync)]
6199 let tool = Arc::new(crate::FetchTool::new(http_client));
6200 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6201
6202 let input: crate::FetchToolInput =
6203 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
6204
6205 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6206
6207 let result = task.await;
6208 assert!(result.is_err(), "expected fetch to be blocked");
6209 assert!(
6210 result.unwrap_err().contains("blocked"),
6211 "error should mention the fetch was blocked"
6212 );
6213}
6214
6215#[gpui::test]
6216async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6217 init_test(cx);
6218
6219 cx.update(|cx| {
6220 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6221 settings.tool_permissions.tools.insert(
6222 FetchTool::NAME.into(),
6223 agent_settings::ToolRules {
6224 default: Some(settings::ToolPermissionMode::Confirm),
6225 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
6226 always_deny: vec![],
6227 always_confirm: vec![],
6228 invalid_patterns: vec![],
6229 },
6230 );
6231 agent_settings::AgentSettings::override_global(settings, cx);
6232 });
6233
6234 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6235
6236 #[allow(clippy::arc_with_non_send_sync)]
6237 let tool = Arc::new(crate::FetchTool::new(http_client));
6238 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6239
6240 let input: crate::FetchToolInput =
6241 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
6242
6243 let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6244
6245 cx.run_until_parked();
6246
6247 let event = rx.try_next();
6248 assert!(
6249 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6250 "expected no authorization request for allowed docs.rs URL"
6251 );
6252}
6253
6254#[gpui::test]
6255async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
6256 init_test(cx);
6257 always_allow_tools(cx);
6258
6259 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
6260 let fake_model = model.as_fake();
6261
6262 // Add a tool so we can simulate tool calls
6263 thread.update(cx, |thread, _cx| {
6264 thread.add_tool(EchoTool);
6265 });
6266
6267 // Start a turn by sending a message
6268 let mut events = thread
6269 .update(cx, |thread, cx| {
6270 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
6271 })
6272 .unwrap();
6273 cx.run_until_parked();
6274
6275 // Simulate the model making a tool call
6276 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
6277 LanguageModelToolUse {
6278 id: "tool_1".into(),
6279 name: "echo".into(),
6280 raw_input: r#"{"text": "hello"}"#.into(),
6281 input: json!({"text": "hello"}),
6282 is_input_complete: true,
6283 thought_signature: None,
6284 },
6285 ));
6286 fake_model
6287 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
6288
6289 // Signal that a message is queued before ending the stream
6290 thread.update(cx, |thread, _cx| {
6291 thread.set_has_queued_message(true);
6292 });
6293
6294 // Now end the stream - tool will run, and the boundary check should see the queue
6295 fake_model.end_last_completion_stream();
6296
6297 // Collect all events until the turn stops
6298 let all_events = collect_events_until_stop(&mut events, cx).await;
6299
6300 // Verify we received the tool call event
6301 let tool_call_ids: Vec<_> = all_events
6302 .iter()
6303 .filter_map(|e| match e {
6304 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
6305 _ => None,
6306 })
6307 .collect();
6308 assert_eq!(
6309 tool_call_ids,
6310 vec!["tool_1"],
6311 "Should have received a tool call event for our echo tool"
6312 );
6313
6314 // The turn should have stopped with EndTurn
6315 let stop_reasons = stop_events(all_events);
6316 assert_eq!(
6317 stop_reasons,
6318 vec![acp::StopReason::EndTurn],
6319 "Turn should have ended after tool completion due to queued message"
6320 );
6321
6322 // Verify the queued message flag is still set
6323 thread.update(cx, |thread, _cx| {
6324 assert!(
6325 thread.has_queued_message(),
6326 "Should still have queued message flag set"
6327 );
6328 });
6329
6330 // Thread should be idle now
6331 thread.update(cx, |thread, _cx| {
6332 assert!(
6333 thread.is_turn_complete(),
6334 "Thread should not be running after turn ends"
6335 );
6336 });
6337}