1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
33};
34use pretty_assertions::assert_eq;
35use project::{
36 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
37};
38use prompt_store::ProjectContext;
39use reqwest_client::ReqwestClient;
40use schemars::JsonSchema;
41use serde::{Deserialize, Serialize};
42use serde_json::json;
43use settings::{Settings, SettingsStore};
44use std::{
45 path::Path,
46 pin::Pin,
47 rc::Rc,
48 sync::{
49 Arc,
50 atomic::{AtomicBool, Ordering},
51 },
52 time::Duration,
53};
54use util::path;
55
56mod edit_file_thread_test;
57mod test_tools;
58use test_tools::*;
59
60fn init_test(cx: &mut TestAppContext) {
61 cx.update(|cx| {
62 let settings_store = SettingsStore::test(cx);
63 cx.set_global(settings_store);
64 });
65}
66
/// Test double for `crate::TerminalHandle` backed by canned data.
///
/// It records whether `kill` was called, reports a configurable
/// "stopped by user" flag, and returns a fixed output snapshot.
struct FakeTerminalHandle {
    /// Set to `true` by `kill`; read by `was_killed`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; written by `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// Exit signal sender; `signal_exit` takes it, so it fires at most once.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared exit future handed out (cloned) by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned output returned by `current_output`.
    output: acp::TerminalOutputResponse,
    /// Terminal id returned by `id`.
    id: acp::TerminalId,
}
75
76impl FakeTerminalHandle {
77 fn new_never_exits(cx: &mut App) -> Self {
78 let killed = Arc::new(AtomicBool::new(false));
79 let stopped_by_user = Arc::new(AtomicBool::new(false));
80
81 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
82
83 let wait_for_exit = cx
84 .spawn(async move |_cx| {
85 // Wait for the exit signal (sent when kill() is called)
86 let _ = exit_receiver.await;
87 acp::TerminalExitStatus::new()
88 })
89 .shared();
90
91 Self {
92 killed,
93 stopped_by_user,
94 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
95 wait_for_exit,
96 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
97 id: acp::TerminalId::new("fake_terminal".to_string()),
98 }
99 }
100
101 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
102 let killed = Arc::new(AtomicBool::new(false));
103 let stopped_by_user = Arc::new(AtomicBool::new(false));
104 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
105
106 let wait_for_exit = cx
107 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
108 .shared();
109
110 Self {
111 killed,
112 stopped_by_user,
113 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
114 wait_for_exit,
115 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
116 id: acp::TerminalId::new("fake_terminal".to_string()),
117 }
118 }
119
120 fn was_killed(&self) -> bool {
121 self.killed.load(Ordering::SeqCst)
122 }
123
124 fn set_stopped_by_user(&self, stopped: bool) {
125 self.stopped_by_user.store(stopped, Ordering::SeqCst);
126 }
127
128 fn signal_exit(&self) {
129 if let Some(sender) = self.exit_sender.borrow_mut().take() {
130 let _ = sender.send(());
131 }
132 }
133}
134
135impl crate::TerminalHandle for FakeTerminalHandle {
136 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
137 Ok(self.id.clone())
138 }
139
140 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
141 Ok(self.output.clone())
142 }
143
144 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
145 Ok(self.wait_for_exit.clone())
146 }
147
148 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
149 self.killed.store(true, Ordering::SeqCst);
150 self.signal_exit();
151 Ok(())
152 }
153
154 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
155 Ok(self.stopped_by_user.load(Ordering::SeqCst))
156 }
157}
158
/// Test double for `SubagentHandle` with a fixed session id and a
/// caller-supplied task whose `String` result is reported as the output.
struct FakeSubagentHandle {
    /// Returned by `SubagentHandle::id`.
    session_id: acp::SessionId,
    /// Awaited by `wait_for_output`; its result is wrapped in `Ok`.
    wait_for_summary_task: Shared<Task<String>>,
}
163
164impl SubagentHandle for FakeSubagentHandle {
165 fn id(&self) -> acp::SessionId {
166 self.session_id.clone()
167 }
168
169 fn wait_for_output(&self, cx: &AsyncApp) -> Task<Result<String>> {
170 let task = self.wait_for_summary_task.clone();
171 cx.background_spawn(async move { Ok(task.await) })
172 }
173}
174
/// Test `ThreadEnvironment` that hands out pre-configured handles.
///
/// Either handle may be left as `None` (the `Default`); requesting a missing
/// handle panics.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Cloned and returned by every `create_terminal` call.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Cloned and returned by every `create_subagent` call.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
180
181impl FakeThreadEnvironment {
182 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
183 Self {
184 terminal_handle: Some(terminal_handle.into()),
185 ..self
186 }
187 }
188}
189
190impl crate::ThreadEnvironment for FakeThreadEnvironment {
191 fn create_terminal(
192 &self,
193 _command: String,
194 _cwd: Option<std::path::PathBuf>,
195 _output_byte_limit: Option<u64>,
196 _cx: &mut AsyncApp,
197 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
198 let handle = self
199 .terminal_handle
200 .clone()
201 .expect("Terminal handle not available on FakeThreadEnvironment");
202 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
203 }
204
205 fn create_subagent(
206 &self,
207 _parent_thread: Entity<Thread>,
208 _label: String,
209 _initial_prompt: String,
210 _timeout_ms: Option<Duration>,
211 _cx: &mut App,
212 ) -> Result<Rc<dyn SubagentHandle>> {
213 Ok(self
214 .subagent_handle
215 .clone()
216 .expect("Subagent handle not available on FakeThreadEnvironment")
217 as Rc<dyn SubagentHandle>)
218 }
219}
220
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
///
/// Each `create_terminal` call builds a fresh never-exiting `FakeTerminalHandle`
/// and records it, so a test can inspect every handle afterwards via `handles()`.
struct MultiTerminalEnvironment {
    /// Every handle created so far, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
225
226impl MultiTerminalEnvironment {
227 fn new() -> Self {
228 Self {
229 handles: std::cell::RefCell::new(Vec::new()),
230 }
231 }
232
233 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
234 self.handles.borrow().clone()
235 }
236}
237
238impl crate::ThreadEnvironment for MultiTerminalEnvironment {
239 fn create_terminal(
240 &self,
241 _command: String,
242 _cwd: Option<std::path::PathBuf>,
243 _output_byte_limit: Option<u64>,
244 cx: &mut AsyncApp,
245 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
246 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
247 self.handles.borrow_mut().push(handle.clone());
248 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
249 }
250
251 fn create_subagent(
252 &self,
253 _parent_thread: Entity<Thread>,
254 _label: String,
255 _initial_prompt: String,
256 _timeout: Option<Duration>,
257 _cx: &mut App,
258 ) -> Result<Rc<dyn SubagentHandle>> {
259 unimplemented!()
260 }
261}
262
263fn always_allow_tools(cx: &mut TestAppContext) {
264 cx.update(|cx| {
265 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
266 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
267 agent_settings::AgentSettings::override_global(settings, cx);
268 });
269}
270
271#[gpui::test]
272async fn test_echo(cx: &mut TestAppContext) {
273 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
274 let fake_model = model.as_fake();
275
276 let events = thread
277 .update(cx, |thread, cx| {
278 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
279 })
280 .unwrap();
281 cx.run_until_parked();
282 fake_model.send_last_completion_stream_text_chunk("Hello");
283 fake_model
284 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
285 fake_model.end_last_completion_stream();
286
287 let events = events.collect().await;
288 thread.update(cx, |thread, _cx| {
289 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
290 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
291 });
292 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
293}
294
/// Verifies that `TerminalTool` kills its terminal handle when the command runs
/// past `timeout_ms`, and that the tool's result still includes the output
/// captured before the kill.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // The fake terminal never exits on its own; only `kill()` resolves its exit future.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so `was_killed` can be inspected after the tool finishes.
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The tool first reports the terminal it created.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Fuse + pin so the task can be polled repeatedly with `now_or_never`.
    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task without blocking. Between polls, run pending gpui work and
    // wait 1ms (presumably giving the 5ms timeout a chance to elapse). The
    // wall-clock deadline turns a hang into a test failure.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
360
/// Verifies that, with no `timeout_ms`, `TerminalTool` leaves a long-running
/// terminal alone (the handle is not killed).
// NOTE(review): this test is `#[ignore]`d with no recorded reason — confirm
// why and document it (e.g. `#[ignore = "..."]`).
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // The fake terminal never exits on its own; only `kill()` resolves its exit future.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Bound to `_task` (not `_`) so the task stays alive for the whole test;
    // presumably dropping it would cancel the tool run.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool time in which a (wrongly applied) timeout could fire.
    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
410
/// Streams a thinking event followed by reply text from the fake model and
/// checks that the assistant message renders both in its markdown.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
        // Thinking content is rendered inside `<think>` tags ahead of the reply text.
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
453
/// Checks that the first request message is a system prompt that includes the
/// project's shell and, because a tool is registered, the
/// "## Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
498
/// With no tools registered, the system prompt must omit the "## Tool Use"
/// and "## Fixing Diagnostics" sections.
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
534
/// Checks prompt-cache marking across turns. In every request, only the final
/// message (the newest user message, or the newest tool result) has
/// `cache: true`; all earlier messages have `cache: false`.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    // `messages[0]` is the system prompt; compare only the conversation after it.
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Running the tool triggers a follow-up request carrying its result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
680
/// End-to-end (only with the `e2e` feature) check against a real model that
/// tool calls complete both before and after the response stream ends.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool, None);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model is asked to echo the delay tool's output; the test expects it to contain "Ding".
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
740
/// End-to-end (only with the `e2e` feature) check that tool input is streamed.
/// While a `word_list` call is still `Pending`, at least one partial input
/// must be observed; once it is no longer pending, the input must contain the
/// keys "a" and "g".
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Ask the model to call `word_list`; its input should arrive in pieces.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool, None);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // Partial: input not complete and the final key "g" not yet present.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
792
/// Exercises the permission flow for `ToolRequiringPermission`:
/// allow-once yields a successful result, deny yields an error result,
/// "always allow" yields success, and after that a further call runs
/// without prompting for authorization.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission, None);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Two tool calls in one response: each should request authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The follow-up request carries both results: success for the approved
    // call, an error result for the denied one.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No authorization response is sent here; a result appearing proves none was needed.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
931
/// A tool call naming an unregistered tool is still surfaced as a `Pending`
/// tool call (titled with the requested name) and is then marked `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
961
962async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
963 let event = events
964 .next()
965 .await
966 .expect("no tool call authorization event received")
967 .unwrap();
968 match event {
969 ThreadEvent::ToolCall(tool_call) => tool_call,
970 event => {
971 panic!("Unexpected event {event:?}");
972 }
973 }
974}
975
976async fn expect_tool_call_update_fields(
977 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
978) -> acp::ToolCallUpdate {
979 let event = events
980 .next()
981 .await
982 .expect("no tool call authorization event received")
983 .unwrap();
984 match event {
985 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
986 event => {
987 panic!("Unexpected event {event:?}");
988 }
989 }
990}
991
992async fn next_tool_call_authorization(
993 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
994) -> ToolCallAuthorization {
995 loop {
996 let event = events
997 .next()
998 .await
999 .expect("no tool call authorization event received")
1000 .unwrap();
1001 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1002 let permission_kinds = tool_call_authorization
1003 .options
1004 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1005 .map(|option| option.kind);
1006 let allow_once = tool_call_authorization
1007 .options
1008 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1009 .map(|option| option.kind);
1010
1011 assert_eq!(
1012 permission_kinds,
1013 Some(acp::PermissionOptionKind::AllowAlways)
1014 );
1015 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1016 return tool_call_authorization;
1017 }
1018 }
1019}
1020
/// A multi-word terminal command yields exactly three dropdown choices:
/// always for the terminal tool, always for the `cargo build` command prefix,
/// and only this time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let permission_options = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert_eq!(choices.len(), 3, "unexpected choices: {labels:?}");
    for expected in [
        "Always for terminal",
        "Always for `cargo build` commands",
        "Only this time",
    ] {
        // Report the actual labels on failure, not just the failed expression.
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1042
/// When the second token of a terminal command is a flag (`ls -la`), the
/// pattern choice covers only the command name (`ls`).
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let permission_options =
        ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()])
            .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert_eq!(choices.len(), 3, "unexpected choices: {labels:?}");
    for expected in [
        "Always for terminal",
        "Always for `ls` commands",
        "Only this time",
    ] {
        // Report the actual labels on failure, not just the failed expression.
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1062
/// A single-word terminal command (`whoami`) still yields a command-pattern
/// choice alongside the tool-wide and one-time choices.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let permission_options =
        ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()])
            .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert_eq!(choices.len(), 3, "unexpected choices: {labels:?}");
    for expected in [
        "Always for terminal",
        "Always for `whoami` commands",
        "Only this time",
    ] {
        // Report the actual labels on failure, not just the failed expression.
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1082
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing `src/main.rs` should offer a tool-wide choice, a choice scoped to the
    // `src/` directory, and the one-shot choice. The count and one-shot assertions
    // match the terminal tests, so an extra or missing choice is caught.
    let permission_options =
        ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()])
            .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for edit file"));
    assert!(labels.contains(&"Always for `src/`"));
    assert!(labels.contains(&"Only this time"));
}
1100
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching `https://docs.rs/gpui` should offer a tool-wide choice, a choice scoped to
    // the `docs.rs` domain, and the one-shot choice. The count and one-shot assertions
    // match the terminal tests, so an extra or missing choice is caught.
    let permission_options =
        ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()])
            .build_permission_options();

    let PermissionOptions::Dropdown(choices) = permission_options else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for fetch"));
    assert!(labels.contains(&"Always for `docs.rs`"));
    assert!(labels.contains(&"Only this time"));
}
1118
#[test]
fn test_permission_options_without_pattern() {
    // For `./deploy.sh --production` no command-pattern choice is offered, leaving only
    // the tool-wide and one-shot choices.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 2);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Only this time"));
    let offers_command_pattern = labels.iter().any(|label| label.contains("commands"));
    assert!(!offers_command_pattern);
}
1140
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization must produce a flat list with exactly one
    // allow-once option (`allow`) and one reject-once option (`deny`).
    let PermissionOptions::Flat(options) =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()])
            .build_permission_options()
    else {
        panic!("Expected flat permission options for symlink target authorization");
    };

    assert_eq!(options.len(), 2);
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1161
#[test]
fn test_permission_option_ids_for_terminal() {
    // Checks the option ids on both sides of each dropdown choice: tool-wide ids name
    // the tool, pattern ids start with a `..._pattern:terminal` prefix followed by a
    // newline, and the one-shot choice uses the bare `allow` / `deny` ids.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();
    let has_id = |ids: &[String], wanted: &str| ids.iter().any(|id| id == wanted);
    let has_prefix = |ids: &[String], prefix: &str| ids.iter().any(|id| id.starts_with(prefix));

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    assert!(
        has_prefix(&allow_ids, "always_allow_pattern:terminal\n"),
        "Missing allow pattern option"
    );

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    assert!(
        has_prefix(&deny_ids, "always_deny_pattern:terminal\n"),
        "Missing deny pattern option"
    );
}
1201
1202#[gpui::test]
1203#[cfg_attr(not(feature = "e2e"), ignore)]
1204async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1205 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1206
1207 // Test concurrent tool calls with different delay times
1208 let events = thread
1209 .update(cx, |thread, cx| {
1210 thread.add_tool(DelayTool, None);
1211 thread.send(
1212 UserMessageId::new(),
1213 [
1214 "Call the delay tool twice in the same message.",
1215 "Once with 100ms. Once with 300ms.",
1216 "When both timers are complete, describe the outputs.",
1217 ],
1218 cx,
1219 )
1220 })
1221 .unwrap()
1222 .collect()
1223 .await;
1224
1225 let stop_reasons = stop_events(events);
1226 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1227
1228 thread.update(cx, |thread, _cx| {
1229 let last_message = thread.last_message().unwrap();
1230 let agent_message = last_message.as_agent_message().unwrap();
1231 let text = agent_message
1232 .content
1233 .iter()
1234 .filter_map(|content| {
1235 if let AgentMessageContent::Text(text) = content {
1236 Some(text.as_str())
1237 } else {
1238 None
1239 }
1240 })
1241 .collect::<String>();
1242
1243 assert!(text.contains("Ding"));
1244 });
1245}
1246
/// Verifies that the active agent profile controls which tools are offered to the model.
///
/// The settings file defines two profiles: `test-1` enables the echo and delay tools, and
/// `test-2` enables only the infinite tool. All three tools are registered on the thread,
/// so each completion request's tool list must match the selected profile.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register every tool; the selected profile decides which ones are exposed.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool, None);
        thread.add_tool(EchoTool, None);
        thread.add_tool(InfiniteTool, None);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::NAME: true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Select the test-1 profile explicitly and verify that only its echo and delay tools
    // are sent (the assertion expects them in name order: delay, then echo).
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
    // Close the first completion stream before sending the next message.
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::NAME]);
}
1326
1327#[gpui::test]
1328async fn test_mcp_tools(cx: &mut TestAppContext) {
1329 let ThreadTest {
1330 model,
1331 thread,
1332 context_server_store,
1333 fs,
1334 ..
1335 } = setup(cx, TestModel::Fake).await;
1336 let fake_model = model.as_fake();
1337
1338 // Override profiles and wait for settings to be loaded.
1339 fs.insert_file(
1340 paths::settings_file(),
1341 json!({
1342 "agent": {
1343 "tool_permissions": { "default": "allow" },
1344 "profiles": {
1345 "test": {
1346 "name": "Test Profile",
1347 "enable_all_context_servers": true,
1348 "tools": {
1349 EchoTool::NAME: true,
1350 }
1351 },
1352 }
1353 }
1354 })
1355 .to_string()
1356 .into_bytes(),
1357 )
1358 .await;
1359 cx.run_until_parked();
1360 thread.update(cx, |thread, cx| {
1361 thread.set_profile(AgentProfileId("test".into()), cx)
1362 });
1363
1364 let mut mcp_tool_calls = setup_context_server(
1365 "test_server",
1366 vec![context_server::types::Tool {
1367 name: "echo".into(),
1368 description: None,
1369 input_schema: serde_json::to_value(EchoTool::input_schema(
1370 LanguageModelToolSchemaFormat::JsonSchema,
1371 ))
1372 .unwrap(),
1373 output_schema: None,
1374 annotations: None,
1375 }],
1376 &context_server_store,
1377 cx,
1378 );
1379
1380 let events = thread.update(cx, |thread, cx| {
1381 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1382 });
1383 cx.run_until_parked();
1384
1385 // Simulate the model calling the MCP tool.
1386 let completion = fake_model.pending_completions().pop().unwrap();
1387 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1388 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1389 LanguageModelToolUse {
1390 id: "tool_1".into(),
1391 name: "echo".into(),
1392 raw_input: json!({"text": "test"}).to_string(),
1393 input: json!({"text": "test"}),
1394 is_input_complete: true,
1395 thought_signature: None,
1396 },
1397 ));
1398 fake_model.end_last_completion_stream();
1399 cx.run_until_parked();
1400
1401 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1402 assert_eq!(tool_call_params.name, "echo");
1403 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1404 tool_call_response
1405 .send(context_server::types::CallToolResponse {
1406 content: vec![context_server::types::ToolResponseContent::Text {
1407 text: "test".into(),
1408 }],
1409 is_error: None,
1410 meta: None,
1411 structured_content: None,
1412 })
1413 .unwrap();
1414 cx.run_until_parked();
1415
1416 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1417 fake_model.send_last_completion_stream_text_chunk("Done!");
1418 fake_model.end_last_completion_stream();
1419 events.collect::<Vec<_>>().await;
1420
1421 // Send again after adding the echo tool, ensuring the name collision is resolved.
1422 let events = thread.update(cx, |thread, cx| {
1423 thread.add_tool(EchoTool, None);
1424 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1425 });
1426 cx.run_until_parked();
1427 let completion = fake_model.pending_completions().pop().unwrap();
1428 assert_eq!(
1429 tool_names_for_completion(&completion),
1430 vec!["echo", "test_server_echo"]
1431 );
1432 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1433 LanguageModelToolUse {
1434 id: "tool_2".into(),
1435 name: "test_server_echo".into(),
1436 raw_input: json!({"text": "mcp"}).to_string(),
1437 input: json!({"text": "mcp"}),
1438 is_input_complete: true,
1439 thought_signature: None,
1440 },
1441 ));
1442 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1443 LanguageModelToolUse {
1444 id: "tool_3".into(),
1445 name: "echo".into(),
1446 raw_input: json!({"text": "native"}).to_string(),
1447 input: json!({"text": "native"}),
1448 is_input_complete: true,
1449 thought_signature: None,
1450 },
1451 ));
1452 fake_model.end_last_completion_stream();
1453 cx.run_until_parked();
1454
1455 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1456 assert_eq!(tool_call_params.name, "echo");
1457 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1458 tool_call_response
1459 .send(context_server::types::CallToolResponse {
1460 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1461 is_error: None,
1462 meta: None,
1463 structured_content: None,
1464 })
1465 .unwrap();
1466 cx.run_until_parked();
1467
1468 // Ensure the tool results were inserted with the correct names.
1469 let completion = fake_model.pending_completions().pop().unwrap();
1470 assert_eq!(
1471 completion.messages.last().unwrap().content,
1472 vec![
1473 MessageContent::ToolResult(LanguageModelToolResult {
1474 tool_use_id: "tool_3".into(),
1475 tool_name: "echo".into(),
1476 is_error: false,
1477 content: "native".into(),
1478 output: Some("native".into()),
1479 },),
1480 MessageContent::ToolResult(LanguageModelToolResult {
1481 tool_use_id: "tool_2".into(),
1482 tool_name: "test_server_echo".into(),
1483 is_error: false,
1484 content: "mcp".into(),
1485 output: Some("mcp".into()),
1486 },),
1487 ]
1488 );
1489 fake_model.end_last_completion_stream();
1490 events.collect::<Vec<_>>().await;
1491}
1492
1493#[gpui::test]
1494async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1495 let ThreadTest {
1496 model,
1497 thread,
1498 context_server_store,
1499 fs,
1500 ..
1501 } = setup(cx, TestModel::Fake).await;
1502 let fake_model = model.as_fake();
1503
1504 // Setup settings to allow MCP tools
1505 fs.insert_file(
1506 paths::settings_file(),
1507 json!({
1508 "agent": {
1509 "always_allow_tool_actions": true,
1510 "profiles": {
1511 "test": {
1512 "name": "Test Profile",
1513 "enable_all_context_servers": true,
1514 "tools": {}
1515 },
1516 }
1517 }
1518 })
1519 .to_string()
1520 .into_bytes(),
1521 )
1522 .await;
1523 cx.run_until_parked();
1524 thread.update(cx, |thread, cx| {
1525 thread.set_profile(AgentProfileId("test".into()), cx)
1526 });
1527
1528 // Setup a context server with a tool
1529 let mut mcp_tool_calls = setup_context_server(
1530 "github_server",
1531 vec![context_server::types::Tool {
1532 name: "issue_read".into(),
1533 description: Some("Read a GitHub issue".into()),
1534 input_schema: json!({
1535 "type": "object",
1536 "properties": {
1537 "issue_url": { "type": "string" }
1538 }
1539 }),
1540 output_schema: None,
1541 annotations: None,
1542 }],
1543 &context_server_store,
1544 cx,
1545 );
1546
1547 // Send a message and have the model call the MCP tool
1548 let events = thread.update(cx, |thread, cx| {
1549 thread
1550 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1551 .unwrap()
1552 });
1553 cx.run_until_parked();
1554
1555 // Verify the MCP tool is available to the model
1556 let completion = fake_model.pending_completions().pop().unwrap();
1557 assert_eq!(
1558 tool_names_for_completion(&completion),
1559 vec!["issue_read"],
1560 "MCP tool should be available"
1561 );
1562
1563 // Simulate the model calling the MCP tool
1564 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1565 LanguageModelToolUse {
1566 id: "tool_1".into(),
1567 name: "issue_read".into(),
1568 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1569 .to_string(),
1570 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1571 is_input_complete: true,
1572 thought_signature: None,
1573 },
1574 ));
1575 fake_model.end_last_completion_stream();
1576 cx.run_until_parked();
1577
1578 // The MCP server receives the tool call and responds with content
1579 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1580 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1581 assert_eq!(tool_call_params.name, "issue_read");
1582 tool_call_response
1583 .send(context_server::types::CallToolResponse {
1584 content: vec![context_server::types::ToolResponseContent::Text {
1585 text: expected_tool_output.into(),
1586 }],
1587 is_error: None,
1588 meta: None,
1589 structured_content: None,
1590 })
1591 .unwrap();
1592 cx.run_until_parked();
1593
1594 // After tool completes, the model continues with a new completion request
1595 // that includes the tool results. We need to respond to this.
1596 let _completion = fake_model.pending_completions().pop().unwrap();
1597 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1598 fake_model
1599 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1600 fake_model.end_last_completion_stream();
1601 events.collect::<Vec<_>>().await;
1602
1603 // Verify the tool result is stored in the thread by checking the markdown output.
1604 // The tool result is in the first assistant message (not the last one, which is
1605 // the model's response after the tool completed).
1606 thread.update(cx, |thread, _cx| {
1607 let markdown = thread.to_markdown();
1608 assert!(
1609 markdown.contains("**Tool Result**: issue_read"),
1610 "Thread should contain tool result header"
1611 );
1612 assert!(
1613 markdown.contains(expected_tool_output),
1614 "Thread should contain tool output: {}",
1615 expected_tool_output
1616 );
1617 });
1618
1619 // Simulate app restart: disconnect the MCP server.
1620 // After restart, the MCP server won't be connected yet when the thread is replayed.
1621 context_server_store.update(cx, |store, cx| {
1622 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1623 });
1624 cx.run_until_parked();
1625
1626 // Replay the thread (this is what happens when loading a saved thread)
1627 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1628
1629 let mut found_tool_call = None;
1630 let mut found_tool_call_update_with_output = None;
1631
1632 while let Some(event) = replay_events.next().await {
1633 let event = event.unwrap();
1634 match &event {
1635 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1636 found_tool_call = Some(tc.clone());
1637 }
1638 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1639 if update.tool_call_id.to_string() == "tool_1" =>
1640 {
1641 if update.fields.raw_output.is_some() {
1642 found_tool_call_update_with_output = Some(update.clone());
1643 }
1644 }
1645 _ => {}
1646 }
1647 }
1648
1649 // The tool call should be found
1650 assert!(
1651 found_tool_call.is_some(),
1652 "Tool call should be emitted during replay"
1653 );
1654
1655 assert!(
1656 found_tool_call_update_with_output.is_some(),
1657 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1658 );
1659
1660 let update = found_tool_call_update_with_output.unwrap();
1661 assert_eq!(
1662 update.fields.raw_output,
1663 Some(expected_tool_output.into()),
1664 "raw_output should contain the saved tool result"
1665 );
1666
1667 // Also verify the status is correct (completed, not failed)
1668 assert_eq!(
1669 update.fields.status,
1670 Some(acp::ToolCallStatus::Completed),
1671 "Tool call status should reflect the original completion status"
1672 );
1673}
1674
/// Verifies how MCP tool names are disambiguated and limited to `MAX_TOOL_NAME_LENGTH`
/// when several context servers expose overlapping or overly long tool names.
///
/// The expected list in the final assertion is the contract. Clashing names are prefixed
/// with the server name (`xxx_echo`, `yyy_echo`, `azure_dev_ops_echo`). Near-limit
/// duplicates get a one-letter server prefix (`y_aaa…`, `z_aaa…`). The `b…` name exposed
/// by two servers appears only once.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    // (the five native tools listed below, plus every context server via
    // `enable_all_context_servers`).
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool, None);
        thread.add_tool(DelayTool, None);
        thread.add_tool(WordListTool, None);
        thread.add_tool(ToolRequiringPermission, None);
        thread.add_tool(InfiniteTool, None);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // `yyy` adds another `echo` clash plus two names just under the length limit
    // (MAX - 2 and MAX - 1), which `zzz` below also defines.
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // `zzz` repeats the near-limit names from `yyy` and adds one name that is one
    // character over the limit (presumably truncated to MAX_TOOL_NAME_LENGTH).
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message so the thread issues a completion request, then inspect the tool
    // names that request carries.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1858
/// End-to-end check (real model, only with the `e2e` feature) of thread cancellation.
///
/// It waits until the echo tool has completed and the infinite tool has been called,
/// then cancels. The event stream must end with a `Cancelled` stop even though a tool
/// may still be running, and a later send must produce a normal `EndTurn` reply.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool, None);
            thread.add_tool(EchoTool, None);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls must arrive in this exact title order; `remove(0)` pops the next
    // expected title (and would panic if an unexpected extra tool call arrived).
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Only a `Completed` status update for the echo call counts.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1944
/// Verifies that cancelling a thread while a terminal command is running kills the
/// terminal. The tool result must keep the captured partial output plus a "user stopped"
/// note, instead of a bare cancellation message, and the thread must still accept new
/// messages afterwards.
#[gpui::test]
async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Back the terminal tool with a fake terminal built by `new_never_exits`, so the
    // command stays running until the test cancels it. Keep a handle so the test can
    // later check `was_killed()`.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Cancel the thread while the terminal is running
    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();

    // Collect remaining events, driving the executor to let cancellation complete
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on cancellation"
    );

    // Verify we got a cancellation stop event
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::Cancelled],
    );

    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the first tool use in the last agent message and look up its result by id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        // "partial output" comes from FakeTerminalHandle's output field
        assert!(
            result_text.contains("partial output"),
            "expected tool result to contain terminal output, got: {result_text}"
        );
        // Match the actual format from process_content in terminal_tool.rs
        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });

    // Verify we can send a new message after cancellation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2041
2042#[gpui::test]
2043async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2044 // This test verifies that tools which properly handle cancellation via
2045 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2046 // to cancellation and report that they were cancelled.
2047 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2048 always_allow_tools(cx);
2049 let fake_model = model.as_fake();
2050
2051 let (tool, was_cancelled) = CancellationAwareTool::new();
2052
2053 let mut events = thread
2054 .update(cx, |thread, cx| {
2055 thread.add_tool(tool, None);
2056 thread.send(
2057 UserMessageId::new(),
2058 ["call the cancellation aware tool"],
2059 cx,
2060 )
2061 })
2062 .unwrap();
2063
2064 cx.run_until_parked();
2065
2066 // Simulate the model calling the cancellation-aware tool
2067 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2068 LanguageModelToolUse {
2069 id: "cancellation_aware_1".into(),
2070 name: "cancellation_aware".into(),
2071 raw_input: r#"{}"#.into(),
2072 input: json!({}),
2073 is_input_complete: true,
2074 thought_signature: None,
2075 },
2076 ));
2077 fake_model.end_last_completion_stream();
2078
2079 cx.run_until_parked();
2080
2081 // Wait for the tool call to be reported
2082 let mut tool_started = false;
2083 let deadline = cx.executor().num_cpus() * 100;
2084 for _ in 0..deadline {
2085 cx.run_until_parked();
2086
2087 while let Some(Some(event)) = events.next().now_or_never() {
2088 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2089 if tool_call.title == "Cancellation Aware Tool" {
2090 tool_started = true;
2091 break;
2092 }
2093 }
2094 }
2095
2096 if tool_started {
2097 break;
2098 }
2099
2100 cx.background_executor
2101 .timer(Duration::from_millis(10))
2102 .await;
2103 }
2104 assert!(tool_started, "expected cancellation aware tool to start");
2105
2106 // Cancel the thread and wait for it to complete
2107 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2108
2109 // The cancel task should complete promptly because the tool handles cancellation
2110 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2111 futures::select! {
2112 _ = cancel_task.fuse() => {}
2113 _ = timeout.fuse() => {
2114 panic!("cancel task timed out - tool did not respond to cancellation");
2115 }
2116 }
2117
2118 // Verify the tool detected cancellation via its flag
2119 assert!(
2120 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2121 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2122 );
2123
2124 // Collect remaining events
2125 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2126
2127 // Verify we got a cancellation stop event
2128 assert_eq!(
2129 stop_events(remaining_events),
2130 vec![acp::StopReason::Cancelled],
2131 );
2132
2133 // Verify we can send a new message after cancellation
2134 verify_thread_recovery(&thread, &fake_model, cx).await;
2135}
2136
2137/// Helper to verify thread can recover after cancellation by sending a simple message.
2138async fn verify_thread_recovery(
2139 thread: &Entity<Thread>,
2140 fake_model: &FakeLanguageModel,
2141 cx: &mut TestAppContext,
2142) {
2143 let events = thread
2144 .update(cx, |thread, cx| {
2145 thread.send(
2146 UserMessageId::new(),
2147 ["Testing: reply with 'Hello' then stop."],
2148 cx,
2149 )
2150 })
2151 .unwrap();
2152 cx.run_until_parked();
2153 fake_model.send_last_completion_stream_text_chunk("Hello");
2154 fake_model
2155 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2156 fake_model.end_last_completion_stream();
2157
2158 let events = events.collect::<Vec<_>>().await;
2159 thread.update(cx, |thread, _cx| {
2160 let message = thread.last_message().unwrap();
2161 let agent_message = message.as_agent_message().unwrap();
2162 assert_eq!(
2163 agent_message.content,
2164 vec![AgentMessageContent::Text("Hello".to_string())]
2165 );
2166 });
2167 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2168}
2169
2170/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2171async fn wait_for_terminal_tool_started(
2172 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2173 cx: &mut TestAppContext,
2174) {
2175 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2176 for _ in 0..deadline {
2177 cx.run_until_parked();
2178
2179 while let Some(Some(event)) = events.next().now_or_never() {
2180 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2181 update,
2182 ))) = &event
2183 {
2184 if update.fields.content.as_ref().is_some_and(|content| {
2185 content
2186 .iter()
2187 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2188 }) {
2189 return;
2190 }
2191 }
2192 }
2193
2194 cx.background_executor
2195 .timer(Duration::from_millis(10))
2196 .await;
2197 }
2198 panic!("terminal tool did not start within the expected time");
2199}
2200
2201/// Collects events until a Stop event is received, driving the executor to completion.
2202async fn collect_events_until_stop(
2203 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2204 cx: &mut TestAppContext,
2205) -> Vec<Result<ThreadEvent>> {
2206 let mut collected = Vec::new();
2207 let deadline = cx.executor().num_cpus() * 200;
2208
2209 for _ in 0..deadline {
2210 cx.executor().advance_clock(Duration::from_millis(10));
2211 cx.run_until_parked();
2212
2213 while let Some(Some(event)) = events.next().now_or_never() {
2214 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2215 collected.push(event);
2216 if is_stop {
2217 return collected;
2218 }
2219 }
2220 }
2221 panic!(
2222 "did not receive Stop event within the expected time; collected {} events",
2223 collected.len()
2224 );
2225}
2226
#[gpui::test]
async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
    // Truncating the thread at its only user message while a terminal tool is still running
    // must kill the terminal, leave the thread empty, and keep the thread usable afterwards.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // Use a fake terminal that never exits on its own, so the tool is guaranteed to still be
    // running when the truncation happens. Keep a handle so it can be inspected afterwards.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    // Keep the message id: it is the truncation point used below.
    let message_id = UserMessageId::new();
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(message_id.clone(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Truncate the thread while the terminal is running
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();

    // Drive the executor to let cancellation complete; the collected events themselves
    // are not inspected here.
    let _ = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on truncate"
    );

    // Verify the thread is empty after truncation
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            "",
            "expected thread to be empty after truncating the only message"
        );
    });

    // Verify we can send a new message after truncation
    verify_thread_recovery(&thread, &fake_model, cx).await;
}
2293
2294#[gpui::test]
2295async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2296 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2297 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2298 always_allow_tools(cx);
2299 let fake_model = model.as_fake();
2300
2301 let environment = Rc::new(MultiTerminalEnvironment::new());
2302
2303 let mut events = thread
2304 .update(cx, |thread, cx| {
2305 thread.add_tool(
2306 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2307 None,
2308 );
2309 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2310 })
2311 .unwrap();
2312
2313 cx.run_until_parked();
2314
2315 // Simulate the model calling two terminal tools
2316 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2317 LanguageModelToolUse {
2318 id: "terminal_tool_1".into(),
2319 name: TerminalTool::NAME.into(),
2320 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2321 input: json!({"command": "sleep 1000", "cd": "."}),
2322 is_input_complete: true,
2323 thought_signature: None,
2324 },
2325 ));
2326 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2327 LanguageModelToolUse {
2328 id: "terminal_tool_2".into(),
2329 name: TerminalTool::NAME.into(),
2330 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2331 input: json!({"command": "sleep 2000", "cd": "."}),
2332 is_input_complete: true,
2333 thought_signature: None,
2334 },
2335 ));
2336 fake_model.end_last_completion_stream();
2337
2338 // Wait for both terminal tools to start by counting terminal content updates
2339 let mut terminals_started = 0;
2340 let deadline = cx.executor().num_cpus() * 100;
2341 for _ in 0..deadline {
2342 cx.run_until_parked();
2343
2344 while let Some(Some(event)) = events.next().now_or_never() {
2345 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2346 update,
2347 ))) = &event
2348 {
2349 if update.fields.content.as_ref().is_some_and(|content| {
2350 content
2351 .iter()
2352 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2353 }) {
2354 terminals_started += 1;
2355 if terminals_started >= 2 {
2356 break;
2357 }
2358 }
2359 }
2360 }
2361 if terminals_started >= 2 {
2362 break;
2363 }
2364
2365 cx.background_executor
2366 .timer(Duration::from_millis(10))
2367 .await;
2368 }
2369 assert!(
2370 terminals_started >= 2,
2371 "expected 2 terminal tools to start, got {terminals_started}"
2372 );
2373
2374 // Cancel the thread while both terminals are running
2375 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2376
2377 // Collect remaining events
2378 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2379
2380 // Verify both terminal handles were killed
2381 let handles = environment.handles();
2382 assert_eq!(
2383 handles.len(),
2384 2,
2385 "expected 2 terminal handles to be created"
2386 );
2387 assert!(
2388 handles[0].was_killed(),
2389 "expected first terminal handle to be killed on cancellation"
2390 );
2391 assert!(
2392 handles[1].was_killed(),
2393 "expected second terminal handle to be killed on cancellation"
2394 );
2395
2396 // Verify we got a cancellation stop event
2397 assert_eq!(
2398 stop_events(remaining_events),
2399 vec![acp::StopReason::Cancelled],
2400 );
2401}
2402
#[gpui::test]
async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
    // Tests that clicking the stop button on the terminal card (as opposed to the main
    // cancel button) properly reports user stopped via the was_stopped_by_user path.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A never-exiting fake terminal; `handle` lets the test flip its stop/kill flags and
    // signal exit directly, standing in for the terminal card's stop button.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
            input: json!({"command": "sleep 1000", "cd": "."}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Simulate user clicking stop on the terminal card itself.
    // This sets the flag and signals exit (simulating what the real UI would do).
    handle.set_stopped_by_user(true);
    handle.killed.store(true, Ordering::SeqCst);
    handle.signal_exit();

    // Wait for the tool to complete
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the (first) tool use in the final agent message, then look up the result
        // recorded for it by id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("The user stopped this command"),
            "expected tool result to indicate user stopped, got: {result_text}"
        );
    });
}
2497
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A never-exiting fake terminal: the only way the tool can finish is via its timeout.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(
                crate::TerminalTool::new(thread.project().clone(), environment),
                None,
            );
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    // (`timeout_ms: 100`, far shorter than the command it runs).
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance the test executor's clock past the 100ms timeout
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the (first) tool use in the final agent message, then look up the result
        // recorded for it by id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2596
2597#[gpui::test]
2598async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2599 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2600 let fake_model = model.as_fake();
2601
2602 let events_1 = thread
2603 .update(cx, |thread, cx| {
2604 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2605 })
2606 .unwrap();
2607 cx.run_until_parked();
2608 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2609 cx.run_until_parked();
2610
2611 let events_2 = thread
2612 .update(cx, |thread, cx| {
2613 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2614 })
2615 .unwrap();
2616 cx.run_until_parked();
2617 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2618 fake_model
2619 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2620 fake_model.end_last_completion_stream();
2621
2622 let events_1 = events_1.collect::<Vec<_>>().await;
2623 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2624 let events_2 = events_2.collect::<Vec<_>>().await;
2625 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2626}
2627
2628#[gpui::test]
2629async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2630 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2631 let fake_model = model.as_fake();
2632
2633 let events_1 = thread
2634 .update(cx, |thread, cx| {
2635 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2636 })
2637 .unwrap();
2638 cx.run_until_parked();
2639 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2640 fake_model
2641 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2642 fake_model.end_last_completion_stream();
2643 let events_1 = events_1.collect::<Vec<_>>().await;
2644
2645 let events_2 = thread
2646 .update(cx, |thread, cx| {
2647 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2648 })
2649 .unwrap();
2650 cx.run_until_parked();
2651 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2652 fake_model
2653 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2654 fake_model.end_last_completion_stream();
2655 let events_2 = events_2.collect::<Vec<_>>().await;
2656
2657 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2658 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2659}
2660
2661#[gpui::test]
2662async fn test_refusal(cx: &mut TestAppContext) {
2663 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2664 let fake_model = model.as_fake();
2665
2666 let events = thread
2667 .update(cx, |thread, cx| {
2668 thread.send(UserMessageId::new(), ["Hello"], cx)
2669 })
2670 .unwrap();
2671 cx.run_until_parked();
2672 thread.read_with(cx, |thread, _| {
2673 assert_eq!(
2674 thread.to_markdown(),
2675 indoc! {"
2676 ## User
2677
2678 Hello
2679 "}
2680 );
2681 });
2682
2683 fake_model.send_last_completion_stream_text_chunk("Hey!");
2684 cx.run_until_parked();
2685 thread.read_with(cx, |thread, _| {
2686 assert_eq!(
2687 thread.to_markdown(),
2688 indoc! {"
2689 ## User
2690
2691 Hello
2692
2693 ## Assistant
2694
2695 Hey!
2696 "}
2697 );
2698 });
2699
2700 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2701 fake_model
2702 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2703 let events = events.collect::<Vec<_>>().await;
2704 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2705 thread.read_with(cx, |thread, _| {
2706 assert_eq!(thread.to_markdown(), "");
2707 });
2708}
2709
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first user message empties the thread and clears its token usage.
    // A fresh exchange afterwards must report only its own usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id of the first message: it is the truncation point below. The event
    // receiver returned by `send` is dropped; this test only inspects thread state.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a reply plus a usage update; `latest_token_usage` should reflect it, with
    // `used_tokens` equal to input + output.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncating at the only user message removes everything, including the usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Only the new exchange's usage is reported, not a sum with the truncated one.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2831
2832#[gpui::test]
2833async fn test_truncate_second_message(cx: &mut TestAppContext) {
2834 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2835 let fake_model = model.as_fake();
2836
2837 thread
2838 .update(cx, |thread, cx| {
2839 thread.send(UserMessageId::new(), ["Message 1"], cx)
2840 })
2841 .unwrap();
2842 cx.run_until_parked();
2843 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2844 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2845 language_model::TokenUsage {
2846 input_tokens: 32_000,
2847 output_tokens: 16_000,
2848 cache_creation_input_tokens: 0,
2849 cache_read_input_tokens: 0,
2850 },
2851 ));
2852 fake_model.end_last_completion_stream();
2853 cx.run_until_parked();
2854
2855 let assert_first_message_state = |cx: &mut TestAppContext| {
2856 thread.clone().read_with(cx, |thread, _| {
2857 assert_eq!(
2858 thread.to_markdown(),
2859 indoc! {"
2860 ## User
2861
2862 Message 1
2863
2864 ## Assistant
2865
2866 Message 1 response
2867 "}
2868 );
2869
2870 assert_eq!(
2871 thread.latest_token_usage(),
2872 Some(acp_thread::TokenUsage {
2873 used_tokens: 32_000 + 16_000,
2874 max_tokens: 1_000_000,
2875 max_output_tokens: None,
2876 input_tokens: 32_000,
2877 output_tokens: 16_000,
2878 })
2879 );
2880 });
2881 };
2882
2883 assert_first_message_state(cx);
2884
2885 let second_message_id = UserMessageId::new();
2886 thread
2887 .update(cx, |thread, cx| {
2888 thread.send(second_message_id.clone(), ["Message 2"], cx)
2889 })
2890 .unwrap();
2891 cx.run_until_parked();
2892
2893 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2894 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2895 language_model::TokenUsage {
2896 input_tokens: 40_000,
2897 output_tokens: 20_000,
2898 cache_creation_input_tokens: 0,
2899 cache_read_input_tokens: 0,
2900 },
2901 ));
2902 fake_model.end_last_completion_stream();
2903 cx.run_until_parked();
2904
2905 thread.read_with(cx, |thread, _| {
2906 assert_eq!(
2907 thread.to_markdown(),
2908 indoc! {"
2909 ## User
2910
2911 Message 1
2912
2913 ## Assistant
2914
2915 Message 1 response
2916
2917 ## User
2918
2919 Message 2
2920
2921 ## Assistant
2922
2923 Message 2 response
2924 "}
2925 );
2926
2927 assert_eq!(
2928 thread.latest_token_usage(),
2929 Some(acp_thread::TokenUsage {
2930 used_tokens: 40_000 + 20_000,
2931 max_tokens: 1_000_000,
2932 max_output_tokens: None,
2933 input_tokens: 40_000,
2934 output_tokens: 20_000,
2935 })
2936 );
2937 });
2938
2939 thread
2940 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2941 .unwrap();
2942 cx.run_until_parked();
2943
2944 assert_first_message_state(cx);
2945}
2946
2947#[gpui::test]
2948async fn test_title_generation(cx: &mut TestAppContext) {
2949 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2950 let fake_model = model.as_fake();
2951
2952 let summary_model = Arc::new(FakeLanguageModel::default());
2953 thread.update(cx, |thread, cx| {
2954 thread.set_summarization_model(Some(summary_model.clone()), cx)
2955 });
2956
2957 let send = thread
2958 .update(cx, |thread, cx| {
2959 thread.send(UserMessageId::new(), ["Hello"], cx)
2960 })
2961 .unwrap();
2962 cx.run_until_parked();
2963
2964 fake_model.send_last_completion_stream_text_chunk("Hey!");
2965 fake_model.end_last_completion_stream();
2966 cx.run_until_parked();
2967 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2968
2969 // Ensure the summary model has been invoked to generate a title.
2970 summary_model.send_last_completion_stream_text_chunk("Hello ");
2971 summary_model.send_last_completion_stream_text_chunk("world\nG");
2972 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2973 summary_model.end_last_completion_stream();
2974 send.collect::<Vec<_>>().await;
2975 cx.run_until_parked();
2976 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2977
2978 // Send another message, ensuring no title is generated this time.
2979 let send = thread
2980 .update(cx, |thread, cx| {
2981 thread.send(UserMessageId::new(), ["Hello again"], cx)
2982 })
2983 .unwrap();
2984 cx.run_until_parked();
2985 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2986 fake_model.end_last_completion_stream();
2987 cx.run_until_parked();
2988 assert_eq!(summary_model.pending_completions(), Vec::new());
2989 send.collect::<Vec<_>>().await;
2990 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2991}
2992
2993#[gpui::test]
2994async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2995 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2996 let fake_model = model.as_fake();
2997
2998 let _events = thread
2999 .update(cx, |thread, cx| {
3000 thread.add_tool(ToolRequiringPermission, None);
3001 thread.add_tool(EchoTool, None);
3002 thread.send(UserMessageId::new(), ["Hey!"], cx)
3003 })
3004 .unwrap();
3005 cx.run_until_parked();
3006
3007 let permission_tool_use = LanguageModelToolUse {
3008 id: "tool_id_1".into(),
3009 name: ToolRequiringPermission::NAME.into(),
3010 raw_input: "{}".into(),
3011 input: json!({}),
3012 is_input_complete: true,
3013 thought_signature: None,
3014 };
3015 let echo_tool_use = LanguageModelToolUse {
3016 id: "tool_id_2".into(),
3017 name: EchoTool::NAME.into(),
3018 raw_input: json!({"text": "test"}).to_string(),
3019 input: json!({"text": "test"}),
3020 is_input_complete: true,
3021 thought_signature: None,
3022 };
3023 fake_model.send_last_completion_stream_text_chunk("Hi!");
3024 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3025 permission_tool_use,
3026 ));
3027 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3028 echo_tool_use.clone(),
3029 ));
3030 fake_model.end_last_completion_stream();
3031 cx.run_until_parked();
3032
3033 // Ensure pending tools are skipped when building a request.
3034 let request = thread
3035 .read_with(cx, |thread, cx| {
3036 thread.build_completion_request(CompletionIntent::EditFile, cx)
3037 })
3038 .unwrap();
3039 assert_eq!(
3040 request.messages[1..],
3041 vec![
3042 LanguageModelRequestMessage {
3043 role: Role::User,
3044 content: vec!["Hey!".into()],
3045 cache: true,
3046 reasoning_details: None,
3047 },
3048 LanguageModelRequestMessage {
3049 role: Role::Assistant,
3050 content: vec![
3051 MessageContent::Text("Hi!".into()),
3052 MessageContent::ToolUse(echo_tool_use.clone())
3053 ],
3054 cache: false,
3055 reasoning_details: None,
3056 },
3057 LanguageModelRequestMessage {
3058 role: Role::User,
3059 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3060 tool_use_id: echo_tool_use.id.clone(),
3061 tool_name: echo_tool_use.name,
3062 is_error: false,
3063 content: "test".into(),
3064 output: Some("test".into())
3065 })],
3066 cache: false,
3067 reasoning_details: None,
3068 },
3069 ],
3070 );
3071}
3072
3073#[gpui::test]
3074async fn test_agent_connection(cx: &mut TestAppContext) {
3075 cx.update(settings::init);
3076 let templates = Templates::new();
3077
3078 // Initialize language model system with test provider
3079 cx.update(|cx| {
3080 gpui_tokio::init(cx);
3081
3082 let http_client = FakeHttpClient::with_404_response();
3083 let clock = Arc::new(clock::FakeSystemClock::new());
3084 let client = Client::new(clock, http_client, cx);
3085 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3086 language_model::init(client.clone(), cx);
3087 language_models::init(user_store, client.clone(), cx);
3088 LanguageModelRegistry::test(cx);
3089 });
3090 cx.executor().forbid_parking();
3091
3092 // Create a project for new_thread
3093 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3094 fake_fs.insert_tree(path!("/test"), json!({})).await;
3095 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3096 let cwd = Path::new("/test");
3097 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3098
3099 // Create agent and connection
3100 let agent = NativeAgent::new(
3101 project.clone(),
3102 thread_store,
3103 templates.clone(),
3104 None,
3105 fake_fs.clone(),
3106 &mut cx.to_async(),
3107 )
3108 .await
3109 .unwrap();
3110 let connection = NativeAgentConnection(agent.clone());
3111
3112 // Create a thread using new_thread
3113 let connection_rc = Rc::new(connection.clone());
3114 let acp_thread = cx
3115 .update(|cx| connection_rc.new_session(project, cwd, cx))
3116 .await
3117 .expect("new_thread should succeed");
3118
3119 // Get the session_id from the AcpThread
3120 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3121
3122 // Test model_selector returns Some
3123 let selector_opt = connection.model_selector(&session_id);
3124 assert!(
3125 selector_opt.is_some(),
3126 "agent should always support ModelSelector"
3127 );
3128 let selector = selector_opt.unwrap();
3129
3130 // Test list_models
3131 let listed_models = cx
3132 .update(|cx| selector.list_models(cx))
3133 .await
3134 .expect("list_models should succeed");
3135 let AgentModelList::Grouped(listed_models) = listed_models else {
3136 panic!("Unexpected model list type");
3137 };
3138 assert!(!listed_models.is_empty(), "should have at least one model");
3139 assert_eq!(
3140 listed_models[&AgentModelGroupName("Fake".into())][0]
3141 .id
3142 .0
3143 .as_ref(),
3144 "fake/fake"
3145 );
3146
3147 // Test selected_model returns the default
3148 let model = cx
3149 .update(|cx| selector.selected_model(cx))
3150 .await
3151 .expect("selected_model should succeed");
3152 let model = cx
3153 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3154 .unwrap();
3155 let model = model.as_fake();
3156 assert_eq!(model.id().0, "fake", "should return default model");
3157
3158 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3159 cx.run_until_parked();
3160 model.send_last_completion_stream_text_chunk("def");
3161 cx.run_until_parked();
3162 acp_thread.read_with(cx, |thread, cx| {
3163 assert_eq!(
3164 thread.to_markdown(cx),
3165 indoc! {"
3166 ## User
3167
3168 abc
3169
3170 ## Assistant
3171
3172 def
3173
3174 "}
3175 )
3176 });
3177
3178 // Test cancel
3179 cx.update(|cx| connection.cancel(&session_id, cx));
3180 request.await.expect("prompt should fail gracefully");
3181
3182 // Explicitly close the session and drop the ACP thread.
3183 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3184 .await
3185 .unwrap();
3186 drop(acp_thread);
3187 let result = cx
3188 .update(|cx| {
3189 connection.prompt(
3190 Some(acp_thread::UserMessageId::new()),
3191 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3192 cx,
3193 )
3194 })
3195 .await;
3196 assert_eq!(
3197 result.as_ref().unwrap_err().to_string(),
3198 "Session not found",
3199 "unexpected result: {:?}",
3200 result
3201 );
3202}
3203
3204#[gpui::test]
3205async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3206 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3207 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3208 let fake_model = model.as_fake();
3209
3210 let mut events = thread
3211 .update(cx, |thread, cx| {
3212 thread.send(UserMessageId::new(), ["Echo something"], cx)
3213 })
3214 .unwrap();
3215 cx.run_until_parked();
3216
3217 // Simulate streaming partial input.
3218 let input = json!({});
3219 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3220 LanguageModelToolUse {
3221 id: "1".into(),
3222 name: EchoTool::NAME.into(),
3223 raw_input: input.to_string(),
3224 input,
3225 is_input_complete: false,
3226 thought_signature: None,
3227 },
3228 ));
3229
3230 // Input streaming completed
3231 let input = json!({ "text": "Hello!" });
3232 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3233 LanguageModelToolUse {
3234 id: "1".into(),
3235 name: "echo".into(),
3236 raw_input: input.to_string(),
3237 input,
3238 is_input_complete: true,
3239 thought_signature: None,
3240 },
3241 ));
3242 fake_model.end_last_completion_stream();
3243 cx.run_until_parked();
3244
3245 let tool_call = expect_tool_call(&mut events).await;
3246 assert_eq!(
3247 tool_call,
3248 acp::ToolCall::new("1", "Echo")
3249 .raw_input(json!({}))
3250 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3251 );
3252 let update = expect_tool_call_update_fields(&mut events).await;
3253 assert_eq!(
3254 update,
3255 acp::ToolCallUpdate::new(
3256 "1",
3257 acp::ToolCallUpdateFields::new()
3258 .title("Echo")
3259 .kind(acp::ToolKind::Other)
3260 .raw_input(json!({ "text": "Hello!"}))
3261 )
3262 );
3263 let update = expect_tool_call_update_fields(&mut events).await;
3264 assert_eq!(
3265 update,
3266 acp::ToolCallUpdate::new(
3267 "1",
3268 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3269 )
3270 );
3271 let update = expect_tool_call_update_fields(&mut events).await;
3272 assert_eq!(
3273 update,
3274 acp::ToolCallUpdate::new(
3275 "1",
3276 acp::ToolCallUpdateFields::new()
3277 .status(acp::ToolCallStatus::Completed)
3278 .raw_output("Hello!")
3279 )
3280 );
3281}
3282
/// Verifies that a completion which streams without errors emits no
/// `ThreadEvent::Retry` events and leaves one user/assistant exchange in the thread.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a complete response with no errors.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, recording any retry notifications.
    // The loop also exits if the stream ends or yields an `Err`.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
3325
/// Verifies that a retryable `ServerOverloaded` error mid-stream causes exactly
/// one retry (attempt 1).
///
/// The partial output from the failed attempt is kept in the thread, followed
/// by a `[resume]` marker and the retried response.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Emit some text, then fail with an error that asks for a retry after 3s.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the requested retry delay so the retry can fire.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion finishes the response.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Drain events until the turn stops, recording retry notifications.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3389
/// Verifies that a tool use received before a retryable error is still completed.
///
/// The retried request must contain both the tool use and its result. After the
/// retry finishes, the last message holds only the new text.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool, None);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // A complete tool use arrives, then the stream fails with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    // Inspect the retried request. The first message is skipped (presumably the
    // system prompt; it is not asserted here).
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried turn and drain all events.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3469
3470#[gpui::test]
3471async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3472 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3473 let fake_model = model.as_fake();
3474
3475 let mut events = thread
3476 .update(cx, |thread, cx| {
3477 thread.send(UserMessageId::new(), ["Hello!"], cx)
3478 })
3479 .unwrap();
3480 cx.run_until_parked();
3481
3482 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3483 fake_model.send_last_completion_stream_error(
3484 LanguageModelCompletionError::ServerOverloaded {
3485 provider: LanguageModelProviderName::new("Anthropic"),
3486 retry_after: Some(Duration::from_secs(3)),
3487 },
3488 );
3489 fake_model.end_last_completion_stream();
3490 cx.executor().advance_clock(Duration::from_secs(3));
3491 cx.run_until_parked();
3492 }
3493
3494 let mut errors = Vec::new();
3495 let mut retry_events = Vec::new();
3496 while let Some(event) = events.next().await {
3497 match event {
3498 Ok(ThreadEvent::Retry(retry_status)) => {
3499 retry_events.push(retry_status);
3500 }
3501 Ok(ThreadEvent::Stop(..)) => break,
3502 Err(error) => errors.push(error),
3503 _ => {}
3504 }
3505 }
3506
3507 assert_eq!(
3508 retry_events.len(),
3509 crate::thread::MAX_RETRY_ATTEMPTS as usize
3510 );
3511 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3512 assert_eq!(retry_events[i].attempt, i + 1);
3513 }
3514 assert_eq!(errors.len(), 1);
3515 let error = errors[0]
3516 .downcast_ref::<LanguageModelCompletionError>()
3517 .unwrap();
3518 assert!(matches!(
3519 error,
3520 LanguageModelCompletionError::ServerOverloaded { .. }
3521 ));
3522}
3523
3524/// Filters out the stop events for asserting against in tests
3525fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3526 result_events
3527 .into_iter()
3528 .filter_map(|event| match event.unwrap() {
3529 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3530 _ => None,
3531 })
3532 .collect()
3533}
3534
/// Handles produced by `setup` for driving a `Thread` in tests.
struct ThreadTest {
    /// The model attached to the thread (a `FakeLanguageModel` for `TestModel::Fake`).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context the thread was created with.
    project_context: Entity<ProjectContext>,
    /// The test project's context server store (e.g. for `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
3542
/// Selects which language model `setup` attaches to the thread under test.
enum TestModel {
    /// The real Claude Sonnet 4 model from the registry. `setup` authenticates
    /// its provider and uses a real HTTP client.
    Sonnet4,
    /// An in-process `FakeLanguageModel` whose responses the test drives manually.
    Fake,
}
3547
3548impl TestModel {
3549 fn id(&self) -> LanguageModelId {
3550 match self {
3551 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3552 TestModel::Fake => unreachable!(),
3553 }
3554 }
3555}
3556
/// Builds a `Thread` over a fake filesystem and a test project rooted at `/test`.
///
/// Steps:
/// 1. Writes a user settings file whose default profile (`test-profile`)
///    enables the test tools.
/// 2. Starts watching that file.
/// 3. Attaches the requested model: a `FakeLanguageModel` for
///    `TestModel::Fake`, or the registry model for `TestModel::Sonnet4`
///    (after its provider authenticates).
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // NOTE(review): parking is allowed presumably because the Sonnet4 path
    // performs real network I/O. Confirm.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // User settings: a default agent profile with every test tool enabled.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        // Only the real model needs HTTP, client, and language model provider setup.
        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Resolve the model. The real model is looked up in the registry and its
    // provider is authenticated before it is returned.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3659
/// Installs `env_logger` at program start-up, but only when the `RUST_LOG`
/// environment variable can be read, so test output stays quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if !logging_requested {
        return;
    }
    env_logger::init();
}
3667
3668fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3669 let fs = fs.clone();
3670 cx.spawn({
3671 async move |cx| {
3672 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3673 cx.background_executor(),
3674 fs,
3675 paths::settings_file().clone(),
3676 );
3677 let _watcher_task = watcher_task;
3678
3679 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3680 cx.update(|cx| {
3681 SettingsStore::update_global(cx, |settings, cx| {
3682 settings.set_user_settings(&new_settings_content, cx)
3683 })
3684 })
3685 .ok();
3686 }
3687 }
3688 })
3689 .detach();
3690}
3691
3692fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3693 completion
3694 .tools
3695 .iter()
3696 .map(|tool| tool.name.clone())
3697 .collect()
3698}
3699
/// Registers and starts a fake stdio context server named `name` that
/// advertises `tools`.
///
/// A settings entry for the server is added to the global `ProjectSettings`.
/// A fake transport answers `Initialize` and `ListTools` directly.
///
/// Every `CallTool` request is forwarded on the returned channel, paired with a
/// oneshot sender. The request waits until the test sends a `CallToolResponse`
/// on that sender, and panics if the sender is dropped instead.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add an enabled stdio server entry for `name` to the global project settings.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report the latest protocol version and tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Always list the tools passed in by the caller.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Hand each tool call to the test and wait for its reply.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3779
/// Verifies `Thread::tokens_before_message`.
///
/// The first user message reports `None`. Each later message reports the input
/// token count from the usage update of the request that preceded it.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3886
/// Verifies `tokens_before_message` after truncation.
///
/// Truncating at a message removes it, so its lookup returns `None`. Earlier
/// messages keep their previous result.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two messages with responses, each reporting a usage update
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
3957
3958#[gpui::test]
3959async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3960 init_test(cx);
3961
3962 let fs = FakeFs::new(cx.executor());
3963 fs.insert_tree("/root", json!({})).await;
3964 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3965
3966 // Test 1: Deny rule blocks command
3967 {
3968 let environment = Rc::new(cx.update(|cx| {
3969 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3970 }));
3971
3972 cx.update(|cx| {
3973 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3974 settings.tool_permissions.tools.insert(
3975 TerminalTool::NAME.into(),
3976 agent_settings::ToolRules {
3977 default: Some(settings::ToolPermissionMode::Confirm),
3978 always_allow: vec![],
3979 always_deny: vec![
3980 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3981 ],
3982 always_confirm: vec![],
3983 invalid_patterns: vec![],
3984 },
3985 );
3986 agent_settings::AgentSettings::override_global(settings, cx);
3987 });
3988
3989 #[allow(clippy::arc_with_non_send_sync)]
3990 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3991 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3992
3993 let task = cx.update(|cx| {
3994 tool.run(
3995 crate::TerminalToolInput {
3996 command: "rm -rf /".to_string(),
3997 cd: ".".to_string(),
3998 timeout_ms: None,
3999 },
4000 event_stream,
4001 cx,
4002 )
4003 });
4004
4005 let result = task.await;
4006 assert!(
4007 result.is_err(),
4008 "expected command to be blocked by deny rule"
4009 );
4010 let err_msg = result.unwrap_err().to_string().to_lowercase();
4011 assert!(
4012 err_msg.contains("blocked"),
4013 "error should mention the command was blocked"
4014 );
4015 }
4016
4017 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4018 {
4019 let environment = Rc::new(cx.update(|cx| {
4020 FakeThreadEnvironment::default()
4021 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4022 }));
4023
4024 cx.update(|cx| {
4025 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4026 settings.tool_permissions.tools.insert(
4027 TerminalTool::NAME.into(),
4028 agent_settings::ToolRules {
4029 default: Some(settings::ToolPermissionMode::Deny),
4030 always_allow: vec![
4031 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4032 ],
4033 always_deny: vec![],
4034 always_confirm: vec![],
4035 invalid_patterns: vec![],
4036 },
4037 );
4038 agent_settings::AgentSettings::override_global(settings, cx);
4039 });
4040
4041 #[allow(clippy::arc_with_non_send_sync)]
4042 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4043 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4044
4045 let task = cx.update(|cx| {
4046 tool.run(
4047 crate::TerminalToolInput {
4048 command: "echo hello".to_string(),
4049 cd: ".".to_string(),
4050 timeout_ms: None,
4051 },
4052 event_stream,
4053 cx,
4054 )
4055 });
4056
4057 let update = rx.expect_update_fields().await;
4058 assert!(
4059 update.content.iter().any(|blocks| {
4060 blocks
4061 .iter()
4062 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4063 }),
4064 "expected terminal content (allow rule should skip confirmation and override default deny)"
4065 );
4066
4067 let result = task.await;
4068 assert!(
4069 result.is_ok(),
4070 "expected command to succeed without confirmation"
4071 );
4072 }
4073
4074 // Test 3: global default: allow does NOT override always_confirm patterns
4075 {
4076 let environment = Rc::new(cx.update(|cx| {
4077 FakeThreadEnvironment::default()
4078 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4079 }));
4080
4081 cx.update(|cx| {
4082 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4083 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4084 settings.tool_permissions.tools.insert(
4085 TerminalTool::NAME.into(),
4086 agent_settings::ToolRules {
4087 default: Some(settings::ToolPermissionMode::Allow),
4088 always_allow: vec![],
4089 always_deny: vec![],
4090 always_confirm: vec![
4091 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4092 ],
4093 invalid_patterns: vec![],
4094 },
4095 );
4096 agent_settings::AgentSettings::override_global(settings, cx);
4097 });
4098
4099 #[allow(clippy::arc_with_non_send_sync)]
4100 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4101 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4102
4103 let _task = cx.update(|cx| {
4104 tool.run(
4105 crate::TerminalToolInput {
4106 command: "sudo rm file".to_string(),
4107 cd: ".".to_string(),
4108 timeout_ms: None,
4109 },
4110 event_stream,
4111 cx,
4112 )
4113 });
4114
4115 // With global default: allow, confirm patterns are still respected
4116 // The expect_authorization() call will panic if no authorization is requested,
4117 // which validates that the confirm pattern still triggers confirmation
4118 let _auth = rx.expect_authorization().await;
4119
4120 drop(_task);
4121 }
4122
4123 // Test 4: tool-specific default: deny is respected even with global default: allow
4124 {
4125 let environment = Rc::new(cx.update(|cx| {
4126 FakeThreadEnvironment::default()
4127 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4128 }));
4129
4130 cx.update(|cx| {
4131 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4132 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4133 settings.tool_permissions.tools.insert(
4134 TerminalTool::NAME.into(),
4135 agent_settings::ToolRules {
4136 default: Some(settings::ToolPermissionMode::Deny),
4137 always_allow: vec![],
4138 always_deny: vec![],
4139 always_confirm: vec![],
4140 invalid_patterns: vec![],
4141 },
4142 );
4143 agent_settings::AgentSettings::override_global(settings, cx);
4144 });
4145
4146 #[allow(clippy::arc_with_non_send_sync)]
4147 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4148 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4149
4150 let task = cx.update(|cx| {
4151 tool.run(
4152 crate::TerminalToolInput {
4153 command: "echo hello".to_string(),
4154 cd: ".".to_string(),
4155 timeout_ms: None,
4156 },
4157 event_stream,
4158 cx,
4159 )
4160 });
4161
4162 // tool-specific default: deny is respected even with global default: allow
4163 let result = task.await;
4164 assert!(
4165 result.is_err(),
4166 "expected command to be blocked by tool-specific deny default"
4167 );
4168 let err_msg = result.unwrap_err().to_string().to_lowercase();
4169 assert!(
4170 err_msg.contains("disabled"),
4171 "error should mention the tool is disabled, got: {err_msg}"
4172 );
4173 }
4174}
4175
4176#[gpui::test]
4177async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4178 init_test(cx);
4179 cx.update(|cx| {
4180 LanguageModelRegistry::test(cx);
4181 });
4182 cx.update(|cx| {
4183 cx.update_flags(true, vec!["subagents".to_string()]);
4184 });
4185
4186 let fs = FakeFs::new(cx.executor());
4187 fs.insert_tree(
4188 "/",
4189 json!({
4190 "a": {
4191 "b.md": "Lorem"
4192 }
4193 }),
4194 )
4195 .await;
4196 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4197 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4198 let agent = NativeAgent::new(
4199 project.clone(),
4200 thread_store.clone(),
4201 Templates::new(),
4202 None,
4203 fs.clone(),
4204 &mut cx.to_async(),
4205 )
4206 .await
4207 .unwrap();
4208 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4209
4210 let acp_thread = cx
4211 .update(|cx| {
4212 connection
4213 .clone()
4214 .new_session(project.clone(), Path::new(""), cx)
4215 })
4216 .await
4217 .unwrap();
4218 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4219 let thread = agent.read_with(cx, |agent, _| {
4220 agent.sessions.get(&session_id).unwrap().thread.clone()
4221 });
4222 let model = Arc::new(FakeLanguageModel::default());
4223
4224 // Ensure empty threads are not saved, even if they get mutated.
4225 thread.update(cx, |thread, cx| {
4226 thread.set_model(model.clone(), cx);
4227 });
4228 cx.run_until_parked();
4229
4230 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4231 cx.run_until_parked();
4232 model.send_last_completion_stream_text_chunk("spawning subagent");
4233 let subagent_tool_input = SubagentToolInput {
4234 label: "label".to_string(),
4235 prompt: "subagent task prompt".to_string(),
4236 timeout: None,
4237 };
4238 let subagent_tool_use = LanguageModelToolUse {
4239 id: "subagent_1".into(),
4240 name: SubagentTool::NAME.into(),
4241 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4242 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4243 is_input_complete: true,
4244 thought_signature: None,
4245 };
4246 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4247 subagent_tool_use,
4248 ));
4249 model.end_last_completion_stream();
4250
4251 cx.run_until_parked();
4252
4253 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4254 thread
4255 .running_subagent_ids(cx)
4256 .get(0)
4257 .expect("subagent thread should be running")
4258 .clone()
4259 });
4260
4261 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4262 agent
4263 .sessions
4264 .get(&subagent_session_id)
4265 .expect("subagent session should exist")
4266 .acp_thread
4267 .clone()
4268 });
4269
4270 model.send_last_completion_stream_text_chunk("subagent task response");
4271 model.end_last_completion_stream();
4272
4273 cx.run_until_parked();
4274
4275 assert_eq!(
4276 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4277 indoc! {"
4278 ## User
4279
4280 subagent task prompt
4281
4282 ## Assistant
4283
4284 subagent task response
4285
4286 "}
4287 );
4288
4289 model.send_last_completion_stream_text_chunk("Response");
4290 model.end_last_completion_stream();
4291
4292 send.await.unwrap();
4293
4294 assert_eq!(
4295 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4296 format!(
4297 indoc! {r#"
4298 ## User
4299
4300 Prompt
4301
4302 ## Assistant
4303
4304 spawning subagent
4305
4306 **Tool Call: label**
4307 Status: Completed
4308
4309 ```json
4310 {{
4311 "session_id": "{}",
4312 "output": "subagent task response\n"
4313 }}
4314 ```
4315
4316 ## Assistant
4317
4318 Response
4319
4320 "#},
4321 subagent_session_id
4322 )
4323 );
4324}
4325
4326#[gpui::test]
4327async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4328 init_test(cx);
4329 cx.update(|cx| {
4330 LanguageModelRegistry::test(cx);
4331 });
4332 cx.update(|cx| {
4333 cx.update_flags(true, vec!["subagents".to_string()]);
4334 });
4335
4336 let fs = FakeFs::new(cx.executor());
4337 fs.insert_tree(
4338 "/",
4339 json!({
4340 "a": {
4341 "b.md": "Lorem"
4342 }
4343 }),
4344 )
4345 .await;
4346 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4347 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4348 let agent = NativeAgent::new(
4349 project.clone(),
4350 thread_store.clone(),
4351 Templates::new(),
4352 None,
4353 fs.clone(),
4354 &mut cx.to_async(),
4355 )
4356 .await
4357 .unwrap();
4358 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4359
4360 let acp_thread = cx
4361 .update(|cx| {
4362 connection
4363 .clone()
4364 .new_session(project.clone(), Path::new(""), cx)
4365 })
4366 .await
4367 .unwrap();
4368 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4369 let thread = agent.read_with(cx, |agent, _| {
4370 agent.sessions.get(&session_id).unwrap().thread.clone()
4371 });
4372 let model = Arc::new(FakeLanguageModel::default());
4373
4374 // Ensure empty threads are not saved, even if they get mutated.
4375 thread.update(cx, |thread, cx| {
4376 thread.set_model(model.clone(), cx);
4377 });
4378 cx.run_until_parked();
4379
4380 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4381 cx.run_until_parked();
4382 model.send_last_completion_stream_text_chunk("spawning subagent");
4383 let subagent_tool_input = SubagentToolInput {
4384 label: "label".to_string(),
4385 prompt: "subagent task prompt".to_string(),
4386 timeout: None,
4387 };
4388 let subagent_tool_use = LanguageModelToolUse {
4389 id: "subagent_1".into(),
4390 name: SubagentTool::NAME.into(),
4391 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4392 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4393 is_input_complete: true,
4394 thought_signature: None,
4395 };
4396 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4397 subagent_tool_use,
4398 ));
4399 model.end_last_completion_stream();
4400
4401 cx.run_until_parked();
4402
4403 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4404 thread
4405 .running_subagent_ids(cx)
4406 .get(0)
4407 .expect("subagent thread should be running")
4408 .clone()
4409 });
4410 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4411 agent
4412 .sessions
4413 .get(&subagent_session_id)
4414 .expect("subagent session should exist")
4415 .acp_thread
4416 .clone()
4417 });
4418
4419 // model.send_last_completion_stream_text_chunk("subagent task response");
4420 // model.end_last_completion_stream();
4421
4422 // cx.run_until_parked();
4423
4424 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4425
4426 cx.run_until_parked();
4427
4428 send.await.unwrap();
4429
4430 acp_thread.read_with(cx, |thread, cx| {
4431 assert_eq!(thread.status(), ThreadStatus::Idle);
4432 assert_eq!(
4433 thread.to_markdown(cx),
4434 indoc! {"
4435 ## User
4436
4437 Prompt
4438
4439 ## Assistant
4440
4441 spawning subagent
4442
4443 **Tool Call: label**
4444 Status: Canceled
4445
4446 "}
4447 );
4448 });
4449 subagent_acp_thread.read_with(cx, |thread, cx| {
4450 assert_eq!(thread.status(), ThreadStatus::Idle);
4451 assert_eq!(
4452 thread.to_markdown(cx),
4453 indoc! {"
4454 ## User
4455
4456 subagent task prompt
4457
4458 "}
4459 );
4460 });
4461}
4462
4463#[gpui::test]
4464async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4465 init_test(cx);
4466
4467 cx.update(|cx| {
4468 cx.update_flags(true, vec!["subagents".to_string()]);
4469 });
4470
4471 let fs = FakeFs::new(cx.executor());
4472 fs.insert_tree(path!("/test"), json!({})).await;
4473 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4474 let project_context = cx.new(|_cx| ProjectContext::default());
4475 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4476 let context_server_registry =
4477 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4478 let model = Arc::new(FakeLanguageModel::default());
4479
4480 let environment = Rc::new(cx.update(|cx| {
4481 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4482 }));
4483
4484 let thread = cx.new(|cx| {
4485 let mut thread = Thread::new(
4486 project.clone(),
4487 project_context,
4488 context_server_registry,
4489 Templates::new(),
4490 Some(model),
4491 cx,
4492 );
4493 thread.add_default_tools(None, environment, cx);
4494 thread
4495 });
4496
4497 thread.read_with(cx, |thread, _| {
4498 assert!(
4499 thread.has_registered_tool(SubagentTool::NAME),
4500 "subagent tool should be present when feature flag is enabled"
4501 );
4502 });
4503}
4504
4505#[gpui::test]
4506async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4507 init_test(cx);
4508
4509 cx.update(|cx| {
4510 cx.update_flags(true, vec!["subagents".to_string()]);
4511 });
4512
4513 let fs = FakeFs::new(cx.executor());
4514 fs.insert_tree(path!("/test"), json!({})).await;
4515 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4516 let project_context = cx.new(|_cx| ProjectContext::default());
4517 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4518 let context_server_registry =
4519 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4520 let model = Arc::new(FakeLanguageModel::default());
4521
4522 let parent_thread = cx.new(|cx| {
4523 Thread::new(
4524 project.clone(),
4525 project_context,
4526 context_server_registry,
4527 Templates::new(),
4528 Some(model.clone()),
4529 cx,
4530 )
4531 });
4532
4533 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4534 subagent_thread.read_with(cx, |subagent_thread, cx| {
4535 assert!(subagent_thread.is_subagent());
4536 assert_eq!(subagent_thread.depth(), 1);
4537 assert_eq!(
4538 subagent_thread.model().map(|model| model.id()),
4539 Some(model.id())
4540 );
4541 assert_eq!(
4542 subagent_thread.parent_thread_id(),
4543 Some(parent_thread.read(cx).id().clone())
4544 );
4545 });
4546}
4547
4548#[gpui::test]
4549async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4550 init_test(cx);
4551
4552 cx.update(|cx| {
4553 cx.update_flags(true, vec!["subagents".to_string()]);
4554 });
4555
4556 let fs = FakeFs::new(cx.executor());
4557 fs.insert_tree(path!("/test"), json!({})).await;
4558 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4559 let project_context = cx.new(|_cx| ProjectContext::default());
4560 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4561 let context_server_registry =
4562 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4563 let model = Arc::new(FakeLanguageModel::default());
4564 let environment = Rc::new(cx.update(|cx| {
4565 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4566 }));
4567
4568 let deep_parent_thread = cx.new(|cx| {
4569 let mut thread = Thread::new(
4570 project.clone(),
4571 project_context,
4572 context_server_registry,
4573 Templates::new(),
4574 Some(model.clone()),
4575 cx,
4576 );
4577 thread.set_subagent_context(SubagentContext {
4578 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4579 depth: MAX_SUBAGENT_DEPTH - 1,
4580 });
4581 thread
4582 });
4583 let deep_subagent_thread = cx.new(|cx| {
4584 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4585 thread.add_default_tools(None, environment, cx);
4586 thread
4587 });
4588
4589 deep_subagent_thread.read_with(cx, |thread, _| {
4590 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4591 assert!(
4592 !thread.has_registered_tool(SubagentTool::NAME),
4593 "subagent tool should not be present at max depth"
4594 );
4595 });
4596}
4597
4598#[gpui::test]
4599async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4600 init_test(cx);
4601
4602 cx.update(|cx| {
4603 cx.update_flags(true, vec!["subagents".to_string()]);
4604 });
4605
4606 let fs = FakeFs::new(cx.executor());
4607 fs.insert_tree(path!("/test"), json!({})).await;
4608 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4609 let project_context = cx.new(|_cx| ProjectContext::default());
4610 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4611 let context_server_registry =
4612 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4613 let model = Arc::new(FakeLanguageModel::default());
4614
4615 let parent = cx.new(|cx| {
4616 Thread::new(
4617 project.clone(),
4618 project_context.clone(),
4619 context_server_registry.clone(),
4620 Templates::new(),
4621 Some(model.clone()),
4622 cx,
4623 )
4624 });
4625
4626 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4627
4628 parent.update(cx, |thread, _cx| {
4629 thread.register_running_subagent(subagent.downgrade());
4630 });
4631
4632 subagent
4633 .update(cx, |thread, cx| {
4634 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4635 })
4636 .unwrap();
4637 cx.run_until_parked();
4638
4639 subagent.read_with(cx, |thread, _| {
4640 assert!(!thread.is_turn_complete(), "subagent should be running");
4641 });
4642
4643 parent.update(cx, |thread, cx| {
4644 thread.cancel(cx).detach();
4645 });
4646
4647 subagent.read_with(cx, |thread, _| {
4648 assert!(
4649 thread.is_turn_complete(),
4650 "subagent should be cancelled when parent cancels"
4651 );
4652 });
4653}
4654
4655#[gpui::test]
4656async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
4657 cx: &mut TestAppContext,
4658) {
4659 init_test(cx);
4660
4661 always_allow_tools(cx);
4662
4663 cx.update(|cx| {
4664 cx.update_flags(true, vec!["subagents".to_string()]);
4665 });
4666
4667 let fs = FakeFs::new(cx.executor());
4668 fs.insert_tree(path!("/test"), json!({})).await;
4669 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4670 let project_context = cx.new(|_cx| ProjectContext::default());
4671 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4672 let context_server_registry =
4673 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4674 cx.update(LanguageModelRegistry::test);
4675 let model = Arc::new(FakeLanguageModel::default());
4676 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4677 let native_agent = NativeAgent::new(
4678 project.clone(),
4679 thread_store,
4680 Templates::new(),
4681 None,
4682 fs,
4683 &mut cx.to_async(),
4684 )
4685 .await
4686 .unwrap();
4687 let parent_thread = cx.new(|cx| {
4688 Thread::new(
4689 project.clone(),
4690 project_context,
4691 context_server_registry,
4692 Templates::new(),
4693 Some(model.clone()),
4694 cx,
4695 )
4696 });
4697
4698 let subagent_handle = cx
4699 .update(|cx| {
4700 NativeThreadEnvironment::create_subagent_thread(
4701 native_agent.downgrade(),
4702 parent_thread.clone(),
4703 "some title".to_string(),
4704 "task prompt".to_string(),
4705 Some(Duration::from_secs(1)),
4706 cx,
4707 )
4708 })
4709 .expect("Failed to create subagent");
4710
4711 let summary_task = subagent_handle.wait_for_output(&cx.to_async());
4712
4713 cx.run_until_parked();
4714
4715 {
4716 let messages = model.pending_completions().last().unwrap().messages.clone();
4717 // Ensure that model received a system prompt
4718 assert_eq!(messages[0].role, Role::System);
4719 // Ensure that model received a task prompt
4720 assert_eq!(
4721 messages[1].content,
4722 vec![MessageContent::Text("task prompt".to_string())]
4723 );
4724 }
4725
4726 // Don't complete the initial model stream — let the timeout expire instead.
4727 cx.executor().advance_clock(Duration::from_secs(2));
4728 cx.run_until_parked();
4729
4730 model.end_last_completion_stream();
4731
4732 let error = summary_task.await.unwrap_err();
4733 assert_eq!(
4734 error.to_string(),
4735 "The time to complete the task was exceeded."
4736 );
4737}
4738
4739#[gpui::test]
4740async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
4741 init_test(cx);
4742
4743 always_allow_tools(cx);
4744
4745 cx.update(|cx| {
4746 cx.update_flags(true, vec!["subagents".to_string()]);
4747 });
4748
4749 let fs = FakeFs::new(cx.executor());
4750 fs.insert_tree(path!("/test"), json!({})).await;
4751 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4752 let project_context = cx.new(|_cx| ProjectContext::default());
4753 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4754 let context_server_registry =
4755 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4756 cx.update(LanguageModelRegistry::test);
4757 let model = Arc::new(FakeLanguageModel::default());
4758 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4759 let native_agent = NativeAgent::new(
4760 project.clone(),
4761 thread_store,
4762 Templates::new(),
4763 None,
4764 fs,
4765 &mut cx.to_async(),
4766 )
4767 .await
4768 .unwrap();
4769 let parent_thread = cx.new(|cx| {
4770 let mut thread = Thread::new(
4771 project.clone(),
4772 project_context,
4773 context_server_registry,
4774 Templates::new(),
4775 Some(model.clone()),
4776 cx,
4777 );
4778 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
4779 thread.add_tool(GrepTool::new(project.clone()), None);
4780 thread
4781 });
4782
4783 let _subagent_handle = cx
4784 .update(|cx| {
4785 NativeThreadEnvironment::create_subagent_thread(
4786 native_agent.downgrade(),
4787 parent_thread.clone(),
4788 "some title".to_string(),
4789 "task prompt".to_string(),
4790 Some(Duration::from_millis(10)),
4791 cx,
4792 )
4793 })
4794 .expect("Failed to create subagent");
4795
4796 cx.run_until_parked();
4797
4798 let tools = model
4799 .pending_completions()
4800 .last()
4801 .unwrap()
4802 .tools
4803 .iter()
4804 .map(|tool| tool.name.clone())
4805 .collect::<Vec<_>>();
4806 assert_eq!(tools.len(), 2);
4807 assert!(tools.contains(&"grep".to_string()));
4808 assert!(tools.contains(&"list_directory".to_string()));
4809}
4810
4811#[gpui::test]
4812async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4813 init_test(cx);
4814
4815 let fs = FakeFs::new(cx.executor());
4816 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4817 .await;
4818 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4819
4820 cx.update(|cx| {
4821 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4822 settings.tool_permissions.tools.insert(
4823 EditFileTool::NAME.into(),
4824 agent_settings::ToolRules {
4825 default: Some(settings::ToolPermissionMode::Allow),
4826 always_allow: vec![],
4827 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
4828 always_confirm: vec![],
4829 invalid_patterns: vec![],
4830 },
4831 );
4832 agent_settings::AgentSettings::override_global(settings, cx);
4833 });
4834
4835 let context_server_registry =
4836 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
4837 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
4838 let templates = crate::Templates::new();
4839 let thread = cx.new(|cx| {
4840 crate::Thread::new(
4841 project.clone(),
4842 cx.new(|_cx| prompt_store::ProjectContext::default()),
4843 context_server_registry,
4844 templates.clone(),
4845 None,
4846 cx,
4847 )
4848 });
4849
4850 #[allow(clippy::arc_with_non_send_sync)]
4851 let tool = Arc::new(crate::EditFileTool::new(
4852 project.clone(),
4853 thread.downgrade(),
4854 language_registry,
4855 templates,
4856 ));
4857 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4858
4859 let task = cx.update(|cx| {
4860 tool.run(
4861 crate::EditFileToolInput {
4862 display_description: "Edit sensitive file".to_string(),
4863 path: "root/sensitive_config.txt".into(),
4864 mode: crate::EditFileMode::Edit,
4865 },
4866 event_stream,
4867 cx,
4868 )
4869 });
4870
4871 let result = task.await;
4872 assert!(result.is_err(), "expected edit to be blocked");
4873 assert!(
4874 result.unwrap_err().to_string().contains("blocked"),
4875 "error should mention the edit was blocked"
4876 );
4877}
4878
4879#[gpui::test]
4880async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
4881 init_test(cx);
4882
4883 let fs = FakeFs::new(cx.executor());
4884 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
4885 .await;
4886 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4887
4888 cx.update(|cx| {
4889 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4890 settings.tool_permissions.tools.insert(
4891 DeletePathTool::NAME.into(),
4892 agent_settings::ToolRules {
4893 default: Some(settings::ToolPermissionMode::Allow),
4894 always_allow: vec![],
4895 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
4896 always_confirm: vec![],
4897 invalid_patterns: vec![],
4898 },
4899 );
4900 agent_settings::AgentSettings::override_global(settings, cx);
4901 });
4902
4903 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
4904
4905 #[allow(clippy::arc_with_non_send_sync)]
4906 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
4907 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4908
4909 let task = cx.update(|cx| {
4910 tool.run(
4911 crate::DeletePathToolInput {
4912 path: "root/important_data.txt".to_string(),
4913 },
4914 event_stream,
4915 cx,
4916 )
4917 });
4918
4919 let result = task.await;
4920 assert!(result.is_err(), "expected deletion to be blocked");
4921 assert!(
4922 result.unwrap_err().to_string().contains("blocked"),
4923 "error should mention the deletion was blocked"
4924 );
4925}
4926
4927#[gpui::test]
4928async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
4929 init_test(cx);
4930
4931 let fs = FakeFs::new(cx.executor());
4932 fs.insert_tree(
4933 "/root",
4934 json!({
4935 "safe.txt": "content",
4936 "protected": {}
4937 }),
4938 )
4939 .await;
4940 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4941
4942 cx.update(|cx| {
4943 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4944 settings.tool_permissions.tools.insert(
4945 MovePathTool::NAME.into(),
4946 agent_settings::ToolRules {
4947 default: Some(settings::ToolPermissionMode::Allow),
4948 always_allow: vec![],
4949 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
4950 always_confirm: vec![],
4951 invalid_patterns: vec![],
4952 },
4953 );
4954 agent_settings::AgentSettings::override_global(settings, cx);
4955 });
4956
4957 #[allow(clippy::arc_with_non_send_sync)]
4958 let tool = Arc::new(crate::MovePathTool::new(project));
4959 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4960
4961 let task = cx.update(|cx| {
4962 tool.run(
4963 crate::MovePathToolInput {
4964 source_path: "root/safe.txt".to_string(),
4965 destination_path: "root/protected/safe.txt".to_string(),
4966 },
4967 event_stream,
4968 cx,
4969 )
4970 });
4971
4972 let result = task.await;
4973 assert!(
4974 result.is_err(),
4975 "expected move to be blocked due to destination path"
4976 );
4977 assert!(
4978 result.unwrap_err().to_string().contains("blocked"),
4979 "error should mention the move was blocked"
4980 );
4981}
4982
4983#[gpui::test]
4984async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
4985 init_test(cx);
4986
4987 let fs = FakeFs::new(cx.executor());
4988 fs.insert_tree(
4989 "/root",
4990 json!({
4991 "secret.txt": "secret content",
4992 "public": {}
4993 }),
4994 )
4995 .await;
4996 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4997
4998 cx.update(|cx| {
4999 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5000 settings.tool_permissions.tools.insert(
5001 MovePathTool::NAME.into(),
5002 agent_settings::ToolRules {
5003 default: Some(settings::ToolPermissionMode::Allow),
5004 always_allow: vec![],
5005 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5006 always_confirm: vec![],
5007 invalid_patterns: vec![],
5008 },
5009 );
5010 agent_settings::AgentSettings::override_global(settings, cx);
5011 });
5012
5013 #[allow(clippy::arc_with_non_send_sync)]
5014 let tool = Arc::new(crate::MovePathTool::new(project));
5015 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5016
5017 let task = cx.update(|cx| {
5018 tool.run(
5019 crate::MovePathToolInput {
5020 source_path: "root/secret.txt".to_string(),
5021 destination_path: "root/public/not_secret.txt".to_string(),
5022 },
5023 event_stream,
5024 cx,
5025 )
5026 });
5027
5028 let result = task.await;
5029 assert!(
5030 result.is_err(),
5031 "expected move to be blocked due to source path"
5032 );
5033 assert!(
5034 result.unwrap_err().to_string().contains("blocked"),
5035 "error should mention the move was blocked"
5036 );
5037}
5038
5039#[gpui::test]
5040async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5041 init_test(cx);
5042
5043 let fs = FakeFs::new(cx.executor());
5044 fs.insert_tree(
5045 "/root",
5046 json!({
5047 "confidential.txt": "confidential data",
5048 "dest": {}
5049 }),
5050 )
5051 .await;
5052 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5053
5054 cx.update(|cx| {
5055 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5056 settings.tool_permissions.tools.insert(
5057 CopyPathTool::NAME.into(),
5058 agent_settings::ToolRules {
5059 default: Some(settings::ToolPermissionMode::Allow),
5060 always_allow: vec![],
5061 always_deny: vec![
5062 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5063 ],
5064 always_confirm: vec![],
5065 invalid_patterns: vec![],
5066 },
5067 );
5068 agent_settings::AgentSettings::override_global(settings, cx);
5069 });
5070
5071 #[allow(clippy::arc_with_non_send_sync)]
5072 let tool = Arc::new(crate::CopyPathTool::new(project));
5073 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5074
5075 let task = cx.update(|cx| {
5076 tool.run(
5077 crate::CopyPathToolInput {
5078 source_path: "root/confidential.txt".to_string(),
5079 destination_path: "root/dest/copy.txt".to_string(),
5080 },
5081 event_stream,
5082 cx,
5083 )
5084 });
5085
5086 let result = task.await;
5087 assert!(result.is_err(), "expected copy to be blocked");
5088 assert!(
5089 result.unwrap_err().to_string().contains("blocked"),
5090 "error should mention the copy was blocked"
5091 );
5092}
5093
5094#[gpui::test]
5095async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5096 init_test(cx);
5097
5098 let fs = FakeFs::new(cx.executor());
5099 fs.insert_tree(
5100 "/root",
5101 json!({
5102 "normal.txt": "normal content",
5103 "readonly": {
5104 "config.txt": "readonly content"
5105 }
5106 }),
5107 )
5108 .await;
5109 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5110
5111 cx.update(|cx| {
5112 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5113 settings.tool_permissions.tools.insert(
5114 SaveFileTool::NAME.into(),
5115 agent_settings::ToolRules {
5116 default: Some(settings::ToolPermissionMode::Allow),
5117 always_allow: vec![],
5118 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5119 always_confirm: vec![],
5120 invalid_patterns: vec![],
5121 },
5122 );
5123 agent_settings::AgentSettings::override_global(settings, cx);
5124 });
5125
5126 #[allow(clippy::arc_with_non_send_sync)]
5127 let tool = Arc::new(crate::SaveFileTool::new(project));
5128 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5129
5130 let task = cx.update(|cx| {
5131 tool.run(
5132 crate::SaveFileToolInput {
5133 paths: vec![
5134 std::path::PathBuf::from("root/normal.txt"),
5135 std::path::PathBuf::from("root/readonly/config.txt"),
5136 ],
5137 },
5138 event_stream,
5139 cx,
5140 )
5141 });
5142
5143 let result = task.await;
5144 assert!(
5145 result.is_err(),
5146 "expected save to be blocked due to denied path"
5147 );
5148 assert!(
5149 result.unwrap_err().to_string().contains("blocked"),
5150 "error should mention the save was blocked"
5151 );
5152}
5153
5154#[gpui::test]
5155async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5156 init_test(cx);
5157
5158 let fs = FakeFs::new(cx.executor());
5159 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5160 .await;
5161 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5162
5163 cx.update(|cx| {
5164 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5165 settings.tool_permissions.tools.insert(
5166 SaveFileTool::NAME.into(),
5167 agent_settings::ToolRules {
5168 default: Some(settings::ToolPermissionMode::Allow),
5169 always_allow: vec![],
5170 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5171 always_confirm: vec![],
5172 invalid_patterns: vec![],
5173 },
5174 );
5175 agent_settings::AgentSettings::override_global(settings, cx);
5176 });
5177
5178 #[allow(clippy::arc_with_non_send_sync)]
5179 let tool = Arc::new(crate::SaveFileTool::new(project));
5180 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5181
5182 let task = cx.update(|cx| {
5183 tool.run(
5184 crate::SaveFileToolInput {
5185 paths: vec![std::path::PathBuf::from("root/config.secret")],
5186 },
5187 event_stream,
5188 cx,
5189 )
5190 });
5191
5192 let result = task.await;
5193 assert!(result.is_err(), "expected save to be blocked");
5194 assert!(
5195 result.unwrap_err().to_string().contains("blocked"),
5196 "error should mention the save was blocked"
5197 );
5198}
5199
5200#[gpui::test]
5201async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5202 init_test(cx);
5203
5204 cx.update(|cx| {
5205 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5206 settings.tool_permissions.tools.insert(
5207 WebSearchTool::NAME.into(),
5208 agent_settings::ToolRules {
5209 default: Some(settings::ToolPermissionMode::Allow),
5210 always_allow: vec![],
5211 always_deny: vec![
5212 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5213 ],
5214 always_confirm: vec![],
5215 invalid_patterns: vec![],
5216 },
5217 );
5218 agent_settings::AgentSettings::override_global(settings, cx);
5219 });
5220
5221 #[allow(clippy::arc_with_non_send_sync)]
5222 let tool = Arc::new(crate::WebSearchTool);
5223 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5224
5225 let input: crate::WebSearchToolInput =
5226 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5227
5228 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5229
5230 let result = task.await;
5231 assert!(result.is_err(), "expected search to be blocked");
5232 assert!(
5233 result.unwrap_err().to_string().contains("blocked"),
5234 "error should mention the search was blocked"
5235 );
5236}
5237
5238#[gpui::test]
5239async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5240 init_test(cx);
5241
5242 let fs = FakeFs::new(cx.executor());
5243 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5244 .await;
5245 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5246
5247 cx.update(|cx| {
5248 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5249 settings.tool_permissions.tools.insert(
5250 EditFileTool::NAME.into(),
5251 agent_settings::ToolRules {
5252 default: Some(settings::ToolPermissionMode::Confirm),
5253 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5254 always_deny: vec![],
5255 always_confirm: vec![],
5256 invalid_patterns: vec![],
5257 },
5258 );
5259 agent_settings::AgentSettings::override_global(settings, cx);
5260 });
5261
5262 let context_server_registry =
5263 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5264 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5265 let templates = crate::Templates::new();
5266 let thread = cx.new(|cx| {
5267 crate::Thread::new(
5268 project.clone(),
5269 cx.new(|_cx| prompt_store::ProjectContext::default()),
5270 context_server_registry,
5271 templates.clone(),
5272 None,
5273 cx,
5274 )
5275 });
5276
5277 #[allow(clippy::arc_with_non_send_sync)]
5278 let tool = Arc::new(crate::EditFileTool::new(
5279 project,
5280 thread.downgrade(),
5281 language_registry,
5282 templates,
5283 ));
5284 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5285
5286 let _task = cx.update(|cx| {
5287 tool.run(
5288 crate::EditFileToolInput {
5289 display_description: "Edit README".to_string(),
5290 path: "root/README.md".into(),
5291 mode: crate::EditFileMode::Edit,
5292 },
5293 event_stream,
5294 cx,
5295 )
5296 });
5297
5298 cx.run_until_parked();
5299
5300 let event = rx.try_next();
5301 assert!(
5302 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5303 "expected no authorization request for allowed .md file"
5304 );
5305}
5306
5307#[gpui::test]
5308async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5309 init_test(cx);
5310
5311 let fs = FakeFs::new(cx.executor());
5312 fs.insert_tree(
5313 "/root",
5314 json!({
5315 ".zed": {
5316 "settings.json": "{}"
5317 },
5318 "README.md": "# Hello"
5319 }),
5320 )
5321 .await;
5322 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5323
5324 cx.update(|cx| {
5325 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5326 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5327 agent_settings::AgentSettings::override_global(settings, cx);
5328 });
5329
5330 let context_server_registry =
5331 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5332 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5333 let templates = crate::Templates::new();
5334 let thread = cx.new(|cx| {
5335 crate::Thread::new(
5336 project.clone(),
5337 cx.new(|_cx| prompt_store::ProjectContext::default()),
5338 context_server_registry,
5339 templates.clone(),
5340 None,
5341 cx,
5342 )
5343 });
5344
5345 #[allow(clippy::arc_with_non_send_sync)]
5346 let tool = Arc::new(crate::EditFileTool::new(
5347 project,
5348 thread.downgrade(),
5349 language_registry,
5350 templates,
5351 ));
5352
5353 // Editing a file inside .zed/ should still prompt even with global default: allow,
5354 // because local settings paths are sensitive and require confirmation regardless.
5355 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5356 let _task = cx.update(|cx| {
5357 tool.run(
5358 crate::EditFileToolInput {
5359 display_description: "Edit local settings".to_string(),
5360 path: "root/.zed/settings.json".into(),
5361 mode: crate::EditFileMode::Edit,
5362 },
5363 event_stream,
5364 cx,
5365 )
5366 });
5367
5368 let _update = rx.expect_update_fields().await;
5369 let _auth = rx.expect_authorization().await;
5370}
5371
5372#[gpui::test]
5373async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5374 init_test(cx);
5375
5376 cx.update(|cx| {
5377 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5378 settings.tool_permissions.tools.insert(
5379 FetchTool::NAME.into(),
5380 agent_settings::ToolRules {
5381 default: Some(settings::ToolPermissionMode::Allow),
5382 always_allow: vec![],
5383 always_deny: vec![
5384 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5385 ],
5386 always_confirm: vec![],
5387 invalid_patterns: vec![],
5388 },
5389 );
5390 agent_settings::AgentSettings::override_global(settings, cx);
5391 });
5392
5393 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5394
5395 #[allow(clippy::arc_with_non_send_sync)]
5396 let tool = Arc::new(crate::FetchTool::new(http_client));
5397 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5398
5399 let input: crate::FetchToolInput =
5400 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5401
5402 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5403
5404 let result = task.await;
5405 assert!(result.is_err(), "expected fetch to be blocked");
5406 assert!(
5407 result.unwrap_err().to_string().contains("blocked"),
5408 "error should mention the fetch was blocked"
5409 );
5410}
5411
5412#[gpui::test]
5413async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5414 init_test(cx);
5415
5416 cx.update(|cx| {
5417 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5418 settings.tool_permissions.tools.insert(
5419 FetchTool::NAME.into(),
5420 agent_settings::ToolRules {
5421 default: Some(settings::ToolPermissionMode::Confirm),
5422 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5423 always_deny: vec![],
5424 always_confirm: vec![],
5425 invalid_patterns: vec![],
5426 },
5427 );
5428 agent_settings::AgentSettings::override_global(settings, cx);
5429 });
5430
5431 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5432
5433 #[allow(clippy::arc_with_non_send_sync)]
5434 let tool = Arc::new(crate::FetchTool::new(http_client));
5435 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5436
5437 let input: crate::FetchToolInput =
5438 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5439
5440 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5441
5442 cx.run_until_parked();
5443
5444 let event = rx.try_next();
5445 assert!(
5446 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5447 "expected no authorization request for allowed docs.rs URL"
5448 );
5449}
5450
5451#[gpui::test]
5452async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5453 init_test(cx);
5454 always_allow_tools(cx);
5455
5456 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5457 let fake_model = model.as_fake();
5458
5459 // Add a tool so we can simulate tool calls
5460 thread.update(cx, |thread, _cx| {
5461 thread.add_tool(EchoTool, None);
5462 });
5463
5464 // Start a turn by sending a message
5465 let mut events = thread
5466 .update(cx, |thread, cx| {
5467 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5468 })
5469 .unwrap();
5470 cx.run_until_parked();
5471
5472 // Simulate the model making a tool call
5473 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5474 LanguageModelToolUse {
5475 id: "tool_1".into(),
5476 name: "echo".into(),
5477 raw_input: r#"{"text": "hello"}"#.into(),
5478 input: json!({"text": "hello"}),
5479 is_input_complete: true,
5480 thought_signature: None,
5481 },
5482 ));
5483 fake_model
5484 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5485
5486 // Signal that a message is queued before ending the stream
5487 thread.update(cx, |thread, _cx| {
5488 thread.set_has_queued_message(true);
5489 });
5490
5491 // Now end the stream - tool will run, and the boundary check should see the queue
5492 fake_model.end_last_completion_stream();
5493
5494 // Collect all events until the turn stops
5495 let all_events = collect_events_until_stop(&mut events, cx).await;
5496
5497 // Verify we received the tool call event
5498 let tool_call_ids: Vec<_> = all_events
5499 .iter()
5500 .filter_map(|e| match e {
5501 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5502 _ => None,
5503 })
5504 .collect();
5505 assert_eq!(
5506 tool_call_ids,
5507 vec!["tool_1"],
5508 "Should have received a tool call event for our echo tool"
5509 );
5510
5511 // The turn should have stopped with EndTurn
5512 let stop_reasons = stop_events(all_events);
5513 assert_eq!(
5514 stop_reasons,
5515 vec![acp::StopReason::EndTurn],
5516 "Turn should have ended after tool completion due to queued message"
5517 );
5518
5519 // Verify the queued message flag is still set
5520 thread.update(cx, |thread, _cx| {
5521 assert!(
5522 thread.has_queued_message(),
5523 "Should still have queued message flag set"
5524 );
5525 });
5526
5527 // Thread should be idle now
5528 thread.update(cx, |thread, _cx| {
5529 assert!(
5530 thread.is_turn_complete(),
5531 "Thread should not be running after turn ends"
5532 );
5533 });
5534}