1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
33};
34use pretty_assertions::assert_eq;
35use project::{
36 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
37};
38use prompt_store::ProjectContext;
39use reqwest_client::ReqwestClient;
40use schemars::JsonSchema;
41use serde::{Deserialize, Serialize};
42use serde_json::json;
43use settings::{Settings, SettingsStore};
44use std::{
45 path::Path,
46 pin::Pin,
47 rc::Rc,
48 sync::{
49 Arc,
50 atomic::{AtomicBool, Ordering},
51 },
52 time::Duration,
53};
54use util::path;
55
56mod edit_file_thread_test;
57mod test_tools;
58use test_tools::*;
59
60fn init_test(cx: &mut TestAppContext) {
61 cx.update(|cx| {
62 let settings_store = SettingsStore::test(cx);
63 cx.set_global(settings_store);
64 });
65}
66
67struct FakeTerminalHandle {
68 killed: Arc<AtomicBool>,
69 stopped_by_user: Arc<AtomicBool>,
70 exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
71 wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
72 output: acp::TerminalOutputResponse,
73 id: acp::TerminalId,
74}
75
76impl FakeTerminalHandle {
77 fn new_never_exits(cx: &mut App) -> Self {
78 let killed = Arc::new(AtomicBool::new(false));
79 let stopped_by_user = Arc::new(AtomicBool::new(false));
80
81 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
82
83 let wait_for_exit = cx
84 .spawn(async move |_cx| {
85 // Wait for the exit signal (sent when kill() is called)
86 let _ = exit_receiver.await;
87 acp::TerminalExitStatus::new()
88 })
89 .shared();
90
91 Self {
92 killed,
93 stopped_by_user,
94 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
95 wait_for_exit,
96 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
97 id: acp::TerminalId::new("fake_terminal".to_string()),
98 }
99 }
100
101 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
102 let killed = Arc::new(AtomicBool::new(false));
103 let stopped_by_user = Arc::new(AtomicBool::new(false));
104 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
105
106 let wait_for_exit = cx
107 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
108 .shared();
109
110 Self {
111 killed,
112 stopped_by_user,
113 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
114 wait_for_exit,
115 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
116 id: acp::TerminalId::new("fake_terminal".to_string()),
117 }
118 }
119
120 fn was_killed(&self) -> bool {
121 self.killed.load(Ordering::SeqCst)
122 }
123
124 fn set_stopped_by_user(&self, stopped: bool) {
125 self.stopped_by_user.store(stopped, Ordering::SeqCst);
126 }
127
128 fn signal_exit(&self) {
129 if let Some(sender) = self.exit_sender.borrow_mut().take() {
130 let _ = sender.send(());
131 }
132 }
133}
134
135impl crate::TerminalHandle for FakeTerminalHandle {
136 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
137 Ok(self.id.clone())
138 }
139
140 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
141 Ok(self.output.clone())
142 }
143
144 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
145 Ok(self.wait_for_exit.clone())
146 }
147
148 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
149 self.killed.store(true, Ordering::SeqCst);
150 self.signal_exit();
151 Ok(())
152 }
153
154 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
155 Ok(self.stopped_by_user.load(Ordering::SeqCst))
156 }
157}
158
159struct FakeSubagentHandle {
160 session_id: acp::SessionId,
161 wait_for_summary_task: Shared<Task<String>>,
162}
163
164impl SubagentHandle for FakeSubagentHandle {
165 fn id(&self) -> acp::SessionId {
166 self.session_id.clone()
167 }
168
169 fn wait_for_summary(&self, _summary_prompt: String, cx: &AsyncApp) -> Task<Result<String>> {
170 let task = self.wait_for_summary_task.clone();
171 cx.background_spawn(async move { Ok(task.await) })
172 }
173}
174
175#[derive(Default)]
176struct FakeThreadEnvironment {
177 terminal_handle: Option<Rc<FakeTerminalHandle>>,
178 subagent_handle: Option<Rc<FakeSubagentHandle>>,
179}
180
181impl FakeThreadEnvironment {
182 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
183 Self {
184 terminal_handle: Some(terminal_handle.into()),
185 ..self
186 }
187 }
188}
189
190impl crate::ThreadEnvironment for FakeThreadEnvironment {
191 fn create_terminal(
192 &self,
193 _command: String,
194 _cwd: Option<std::path::PathBuf>,
195 _output_byte_limit: Option<u64>,
196 _cx: &mut AsyncApp,
197 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
198 let handle = self
199 .terminal_handle
200 .clone()
201 .expect("Terminal handle not available on FakeThreadEnvironment");
202 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
203 }
204
205 fn create_subagent(
206 &self,
207 _parent_thread: Entity<Thread>,
208 _label: String,
209 _initial_prompt: String,
210 _timeout_ms: Option<Duration>,
211 _allowed_tools: Option<Vec<String>>,
212 _cx: &mut App,
213 ) -> Result<Rc<dyn SubagentHandle>> {
214 Ok(self
215 .subagent_handle
216 .clone()
217 .expect("Subagent handle not available on FakeThreadEnvironment")
218 as Rc<dyn SubagentHandle>)
219 }
220}
221
222/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
223struct MultiTerminalEnvironment {
224 handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
225}
226
227impl MultiTerminalEnvironment {
228 fn new() -> Self {
229 Self {
230 handles: std::cell::RefCell::new(Vec::new()),
231 }
232 }
233
234 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
235 self.handles.borrow().clone()
236 }
237}
238
239impl crate::ThreadEnvironment for MultiTerminalEnvironment {
240 fn create_terminal(
241 &self,
242 _command: String,
243 _cwd: Option<std::path::PathBuf>,
244 _output_byte_limit: Option<u64>,
245 cx: &mut AsyncApp,
246 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
247 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
248 self.handles.borrow_mut().push(handle.clone());
249 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
250 }
251
252 fn create_subagent(
253 &self,
254 _parent_thread: Entity<Thread>,
255 _label: String,
256 _initial_prompt: String,
257 _timeout: Option<Duration>,
258 _allowed_tools: Option<Vec<String>>,
259 _cx: &mut App,
260 ) -> Result<Rc<dyn SubagentHandle>> {
261 unimplemented!()
262 }
263}
264
265fn always_allow_tools(cx: &mut TestAppContext) {
266 cx.update(|cx| {
267 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
268 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
269 agent_settings::AgentSettings::override_global(settings, cx);
270 });
271}
272
273#[gpui::test]
274async fn test_echo(cx: &mut TestAppContext) {
275 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
276 let fake_model = model.as_fake();
277
278 let events = thread
279 .update(cx, |thread, cx| {
280 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
281 })
282 .unwrap();
283 cx.run_until_parked();
284 fake_model.send_last_completion_stream_text_chunk("Hello");
285 fake_model
286 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
287 fake_model.end_last_completion_stream();
288
289 let events = events.collect().await;
290 thread.update(cx, |thread, _cx| {
291 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
292 assert_eq!(thread.last_message().unwrap().to_markdown(), "Hello\n")
293 });
294 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
295}
296
297#[gpui::test]
298async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
299 init_test(cx);
300 always_allow_tools(cx);
301
302 let fs = FakeFs::new(cx.executor());
303 let project = Project::test(fs, [], cx).await;
304
305 let environment = Rc::new(cx.update(|cx| {
306 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
307 }));
308 let handle = environment.terminal_handle.clone().unwrap();
309
310 #[allow(clippy::arc_with_non_send_sync)]
311 let tool = Arc::new(crate::TerminalTool::new(project, environment));
312 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
313
314 let task = cx.update(|cx| {
315 tool.run(
316 crate::TerminalToolInput {
317 command: "sleep 1000".to_string(),
318 cd: ".".to_string(),
319 timeout_ms: Some(5),
320 },
321 event_stream,
322 cx,
323 )
324 });
325
326 let update = rx.expect_update_fields().await;
327 assert!(
328 update.content.iter().any(|blocks| {
329 blocks
330 .iter()
331 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
332 }),
333 "expected tool call update to include terminal content"
334 );
335
336 let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
337
338 let deadline = std::time::Instant::now() + Duration::from_millis(500);
339 loop {
340 if let Some(result) = task_future.as_mut().now_or_never() {
341 let result = result.expect("terminal tool task should complete");
342
343 assert!(
344 handle.was_killed(),
345 "expected terminal handle to be killed on timeout"
346 );
347 assert!(
348 result.contains("partial output"),
349 "expected result to include terminal output, got: {result}"
350 );
351 return;
352 }
353
354 if std::time::Instant::now() >= deadline {
355 panic!("timed out waiting for terminal tool task to complete");
356 }
357
358 cx.run_until_parked();
359 cx.background_executor.timer(Duration::from_millis(1)).await;
360 }
361}
362
363#[gpui::test]
364#[ignore]
365async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
366 init_test(cx);
367 always_allow_tools(cx);
368
369 let fs = FakeFs::new(cx.executor());
370 let project = Project::test(fs, [], cx).await;
371
372 let environment = Rc::new(cx.update(|cx| {
373 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
374 }));
375 let handle = environment.terminal_handle.clone().unwrap();
376
377 #[allow(clippy::arc_with_non_send_sync)]
378 let tool = Arc::new(crate::TerminalTool::new(project, environment));
379 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
380
381 let _task = cx.update(|cx| {
382 tool.run(
383 crate::TerminalToolInput {
384 command: "sleep 1000".to_string(),
385 cd: ".".to_string(),
386 timeout_ms: None,
387 },
388 event_stream,
389 cx,
390 )
391 });
392
393 let update = rx.expect_update_fields().await;
394 assert!(
395 update.content.iter().any(|blocks| {
396 blocks
397 .iter()
398 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
399 }),
400 "expected tool call update to include terminal content"
401 );
402
403 cx.background_executor
404 .timer(Duration::from_millis(25))
405 .await;
406
407 assert!(
408 !handle.was_killed(),
409 "did not expect terminal handle to be killed without a timeout"
410 );
411}
412
413#[gpui::test]
414async fn test_thinking(cx: &mut TestAppContext) {
415 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
416 let fake_model = model.as_fake();
417
418 let events = thread
419 .update(cx, |thread, cx| {
420 thread.send(
421 UserMessageId::new(),
422 [indoc! {"
423 Testing:
424
425 Generate a thinking step where you just think the word 'Think',
426 and have your final answer be 'Hello'
427 "}],
428 cx,
429 )
430 })
431 .unwrap();
432 cx.run_until_parked();
433 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
434 text: "Think".to_string(),
435 signature: None,
436 });
437 fake_model.send_last_completion_stream_text_chunk("Hello");
438 fake_model
439 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
440 fake_model.end_last_completion_stream();
441
442 let events = events.collect().await;
443 thread.update(cx, |thread, _cx| {
444 assert_eq!(thread.last_message().unwrap().role(), Role::Assistant);
445 assert_eq!(
446 thread.last_message().unwrap().to_markdown(),
447 indoc! {"
448 <think>Think</think>
449 Hello
450 "}
451 )
452 });
453 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
454}
455
456#[gpui::test]
457async fn test_system_prompt(cx: &mut TestAppContext) {
458 let ThreadTest {
459 model,
460 thread,
461 project_context,
462 ..
463 } = setup(cx, TestModel::Fake).await;
464 let fake_model = model.as_fake();
465
466 project_context.update(cx, |project_context, _cx| {
467 project_context.shell = "test-shell".into()
468 });
469 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
470 thread
471 .update(cx, |thread, cx| {
472 thread.send(UserMessageId::new(), ["abc"], cx)
473 })
474 .unwrap();
475 cx.run_until_parked();
476 let mut pending_completions = fake_model.pending_completions();
477 assert_eq!(
478 pending_completions.len(),
479 1,
480 "unexpected pending completions: {:?}",
481 pending_completions
482 );
483
484 let pending_completion = pending_completions.pop().unwrap();
485 assert_eq!(pending_completion.messages[0].role, Role::System);
486
487 let system_message = &pending_completion.messages[0];
488 let system_prompt = system_message.content[0].to_str().unwrap();
489 assert!(
490 system_prompt.contains("test-shell"),
491 "unexpected system message: {:?}",
492 system_message
493 );
494 assert!(
495 system_prompt.contains("## Fixing Diagnostics"),
496 "unexpected system message: {:?}",
497 system_message
498 );
499}
500
501#[gpui::test]
502async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
503 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
504 let fake_model = model.as_fake();
505
506 thread
507 .update(cx, |thread, cx| {
508 thread.send(UserMessageId::new(), ["abc"], cx)
509 })
510 .unwrap();
511 cx.run_until_parked();
512 let mut pending_completions = fake_model.pending_completions();
513 assert_eq!(
514 pending_completions.len(),
515 1,
516 "unexpected pending completions: {:?}",
517 pending_completions
518 );
519
520 let pending_completion = pending_completions.pop().unwrap();
521 assert_eq!(pending_completion.messages[0].role, Role::System);
522
523 let system_message = &pending_completion.messages[0];
524 let system_prompt = system_message.content[0].to_str().unwrap();
525 assert!(
526 !system_prompt.contains("## Tool Use"),
527 "unexpected system message: {:?}",
528 system_message
529 );
530 assert!(
531 !system_prompt.contains("## Fixing Diagnostics"),
532 "unexpected system message: {:?}",
533 system_message
534 );
535}
536
537#[gpui::test]
538async fn test_prompt_caching(cx: &mut TestAppContext) {
539 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
540 let fake_model = model.as_fake();
541
542 // Send initial user message and verify it's cached
543 thread
544 .update(cx, |thread, cx| {
545 thread.send(UserMessageId::new(), ["Message 1"], cx)
546 })
547 .unwrap();
548 cx.run_until_parked();
549
550 let completion = fake_model.pending_completions().pop().unwrap();
551 assert_eq!(
552 completion.messages[1..],
553 vec![LanguageModelRequestMessage {
554 role: Role::User,
555 content: vec!["Message 1".into()],
556 cache: true,
557 reasoning_details: None,
558 }]
559 );
560 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
561 "Response to Message 1".into(),
562 ));
563 fake_model.end_last_completion_stream();
564 cx.run_until_parked();
565
566 // Send another user message and verify only the latest is cached
567 thread
568 .update(cx, |thread, cx| {
569 thread.send(UserMessageId::new(), ["Message 2"], cx)
570 })
571 .unwrap();
572 cx.run_until_parked();
573
574 let completion = fake_model.pending_completions().pop().unwrap();
575 assert_eq!(
576 completion.messages[1..],
577 vec![
578 LanguageModelRequestMessage {
579 role: Role::User,
580 content: vec!["Message 1".into()],
581 cache: false,
582 reasoning_details: None,
583 },
584 LanguageModelRequestMessage {
585 role: Role::Assistant,
586 content: vec!["Response to Message 1".into()],
587 cache: false,
588 reasoning_details: None,
589 },
590 LanguageModelRequestMessage {
591 role: Role::User,
592 content: vec!["Message 2".into()],
593 cache: true,
594 reasoning_details: None,
595 }
596 ]
597 );
598 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
599 "Response to Message 2".into(),
600 ));
601 fake_model.end_last_completion_stream();
602 cx.run_until_parked();
603
604 // Simulate a tool call and verify that the latest tool result is cached
605 thread.update(cx, |thread, _| thread.add_tool(EchoTool, None));
606 thread
607 .update(cx, |thread, cx| {
608 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
609 })
610 .unwrap();
611 cx.run_until_parked();
612
613 let tool_use = LanguageModelToolUse {
614 id: "tool_1".into(),
615 name: EchoTool::NAME.into(),
616 raw_input: json!({"text": "test"}).to_string(),
617 input: json!({"text": "test"}),
618 is_input_complete: true,
619 thought_signature: None,
620 };
621 fake_model
622 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
623 fake_model.end_last_completion_stream();
624 cx.run_until_parked();
625
626 let completion = fake_model.pending_completions().pop().unwrap();
627 let tool_result = LanguageModelToolResult {
628 tool_use_id: "tool_1".into(),
629 tool_name: EchoTool::NAME.into(),
630 is_error: false,
631 content: "test".into(),
632 output: Some("test".into()),
633 };
634 assert_eq!(
635 completion.messages[1..],
636 vec![
637 LanguageModelRequestMessage {
638 role: Role::User,
639 content: vec!["Message 1".into()],
640 cache: false,
641 reasoning_details: None,
642 },
643 LanguageModelRequestMessage {
644 role: Role::Assistant,
645 content: vec!["Response to Message 1".into()],
646 cache: false,
647 reasoning_details: None,
648 },
649 LanguageModelRequestMessage {
650 role: Role::User,
651 content: vec!["Message 2".into()],
652 cache: false,
653 reasoning_details: None,
654 },
655 LanguageModelRequestMessage {
656 role: Role::Assistant,
657 content: vec!["Response to Message 2".into()],
658 cache: false,
659 reasoning_details: None,
660 },
661 LanguageModelRequestMessage {
662 role: Role::User,
663 content: vec!["Use the echo tool".into()],
664 cache: false,
665 reasoning_details: None,
666 },
667 LanguageModelRequestMessage {
668 role: Role::Assistant,
669 content: vec![MessageContent::ToolUse(tool_use)],
670 cache: false,
671 reasoning_details: None,
672 },
673 LanguageModelRequestMessage {
674 role: Role::User,
675 content: vec![MessageContent::ToolResult(tool_result)],
676 cache: true,
677 reasoning_details: None,
678 }
679 ]
680 );
681}
682
683#[gpui::test]
684#[cfg_attr(not(feature = "e2e"), ignore)]
685async fn test_basic_tool_calls(cx: &mut TestAppContext) {
686 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
687
688 // Test a tool call that's likely to complete *before* streaming stops.
689 let events = thread
690 .update(cx, |thread, cx| {
691 thread.add_tool(EchoTool, None);
692 thread.send(
693 UserMessageId::new(),
694 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
695 cx,
696 )
697 })
698 .unwrap()
699 .collect()
700 .await;
701 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
702
703 // Test a tool calls that's likely to complete *after* streaming stops.
704 let events = thread
705 .update(cx, |thread, cx| {
706 thread.remove_tool(&EchoTool::NAME);
707 thread.add_tool(DelayTool, None);
708 thread.send(
709 UserMessageId::new(),
710 [
711 "Now call the delay tool with 200ms.",
712 "When the timer goes off, then you echo the output of the tool.",
713 ],
714 cx,
715 )
716 })
717 .unwrap()
718 .collect()
719 .await;
720 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
721 thread.update(cx, |thread, _cx| {
722 assert!(
723 thread
724 .last_message()
725 .unwrap()
726 .as_agent_message()
727 .unwrap()
728 .content
729 .iter()
730 .any(|content| {
731 if let AgentMessageContent::Text(text) = content {
732 text.contains("Ding")
733 } else {
734 false
735 }
736 }),
737 "{}",
738 thread.to_markdown()
739 );
740 });
741}
742
743#[gpui::test]
744#[cfg_attr(not(feature = "e2e"), ignore)]
745async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
746 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
747
748 // Test a tool call that's likely to complete *before* streaming stops.
749 let mut events = thread
750 .update(cx, |thread, cx| {
751 thread.add_tool(WordListTool, None);
752 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
753 })
754 .unwrap();
755
756 let mut saw_partial_tool_use = false;
757 while let Some(event) = events.next().await {
758 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
759 thread.update(cx, |thread, _cx| {
760 // Look for a tool use in the thread's last message
761 let message = thread.last_message().unwrap();
762 let agent_message = message.as_agent_message().unwrap();
763 let last_content = agent_message.content.last().unwrap();
764 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
765 assert_eq!(last_tool_use.name.as_ref(), "word_list");
766 if tool_call.status == acp::ToolCallStatus::Pending {
767 if !last_tool_use.is_input_complete
768 && last_tool_use.input.get("g").is_none()
769 {
770 saw_partial_tool_use = true;
771 }
772 } else {
773 last_tool_use
774 .input
775 .get("a")
776 .expect("'a' has streamed because input is now complete");
777 last_tool_use
778 .input
779 .get("g")
780 .expect("'g' has streamed because input is now complete");
781 }
782 } else {
783 panic!("last content should be a tool use");
784 }
785 });
786 }
787 }
788
789 assert!(
790 saw_partial_tool_use,
791 "should see at least one partially streamed tool use in the history"
792 );
793}
794
795#[gpui::test]
796async fn test_tool_authorization(cx: &mut TestAppContext) {
797 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
798 let fake_model = model.as_fake();
799
800 let mut events = thread
801 .update(cx, |thread, cx| {
802 thread.add_tool(ToolRequiringPermission, None);
803 thread.send(UserMessageId::new(), ["abc"], cx)
804 })
805 .unwrap();
806 cx.run_until_parked();
807 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
808 LanguageModelToolUse {
809 id: "tool_id_1".into(),
810 name: ToolRequiringPermission::NAME.into(),
811 raw_input: "{}".into(),
812 input: json!({}),
813 is_input_complete: true,
814 thought_signature: None,
815 },
816 ));
817 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
818 LanguageModelToolUse {
819 id: "tool_id_2".into(),
820 name: ToolRequiringPermission::NAME.into(),
821 raw_input: "{}".into(),
822 input: json!({}),
823 is_input_complete: true,
824 thought_signature: None,
825 },
826 ));
827 fake_model.end_last_completion_stream();
828 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
829 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
830
831 // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
832 tool_call_auth_1
833 .response
834 .send(acp::PermissionOptionId::new("allow"))
835 .unwrap();
836 cx.run_until_parked();
837
838 // Reject the second - send "deny" option_id directly since Deny is now a button
839 tool_call_auth_2
840 .response
841 .send(acp::PermissionOptionId::new("deny"))
842 .unwrap();
843 cx.run_until_parked();
844
845 let completion = fake_model.pending_completions().pop().unwrap();
846 let message = completion.messages.last().unwrap();
847 assert_eq!(
848 message.content,
849 vec![
850 language_model::MessageContent::ToolResult(LanguageModelToolResult {
851 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
852 tool_name: ToolRequiringPermission::NAME.into(),
853 is_error: false,
854 content: "Allowed".into(),
855 output: Some("Allowed".into())
856 }),
857 language_model::MessageContent::ToolResult(LanguageModelToolResult {
858 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
859 tool_name: ToolRequiringPermission::NAME.into(),
860 is_error: true,
861 content: "Permission to run tool denied by user".into(),
862 output: Some("Permission to run tool denied by user".into())
863 })
864 ]
865 );
866
867 // Simulate yet another tool call.
868 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
869 LanguageModelToolUse {
870 id: "tool_id_3".into(),
871 name: ToolRequiringPermission::NAME.into(),
872 raw_input: "{}".into(),
873 input: json!({}),
874 is_input_complete: true,
875 thought_signature: None,
876 },
877 ));
878 fake_model.end_last_completion_stream();
879
880 // Respond by always allowing tools - send transformed option_id
881 // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
882 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
883 tool_call_auth_3
884 .response
885 .send(acp::PermissionOptionId::new(
886 "always_allow:tool_requiring_permission",
887 ))
888 .unwrap();
889 cx.run_until_parked();
890 let completion = fake_model.pending_completions().pop().unwrap();
891 let message = completion.messages.last().unwrap();
892 assert_eq!(
893 message.content,
894 vec![language_model::MessageContent::ToolResult(
895 LanguageModelToolResult {
896 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
897 tool_name: ToolRequiringPermission::NAME.into(),
898 is_error: false,
899 content: "Allowed".into(),
900 output: Some("Allowed".into())
901 }
902 )]
903 );
904
905 // Simulate a final tool call, ensuring we don't trigger authorization.
906 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
907 LanguageModelToolUse {
908 id: "tool_id_4".into(),
909 name: ToolRequiringPermission::NAME.into(),
910 raw_input: "{}".into(),
911 input: json!({}),
912 is_input_complete: true,
913 thought_signature: None,
914 },
915 ));
916 fake_model.end_last_completion_stream();
917 cx.run_until_parked();
918 let completion = fake_model.pending_completions().pop().unwrap();
919 let message = completion.messages.last().unwrap();
920 assert_eq!(
921 message.content,
922 vec![language_model::MessageContent::ToolResult(
923 LanguageModelToolResult {
924 tool_use_id: "tool_id_4".into(),
925 tool_name: ToolRequiringPermission::NAME.into(),
926 is_error: false,
927 content: "Allowed".into(),
928 output: Some("Allowed".into())
929 }
930 )]
931 );
932}
933
934#[gpui::test]
935async fn test_tool_hallucination(cx: &mut TestAppContext) {
936 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
937 let fake_model = model.as_fake();
938
939 let mut events = thread
940 .update(cx, |thread, cx| {
941 thread.send(UserMessageId::new(), ["abc"], cx)
942 })
943 .unwrap();
944 cx.run_until_parked();
945 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
946 LanguageModelToolUse {
947 id: "tool_id_1".into(),
948 name: "nonexistent_tool".into(),
949 raw_input: "{}".into(),
950 input: json!({}),
951 is_input_complete: true,
952 thought_signature: None,
953 },
954 ));
955 fake_model.end_last_completion_stream();
956
957 let tool_call = expect_tool_call(&mut events).await;
958 assert_eq!(tool_call.title, "nonexistent_tool");
959 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
960 let update = expect_tool_call_update_fields(&mut events).await;
961 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
962}
963
964async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
965 let event = events
966 .next()
967 .await
968 .expect("no tool call authorization event received")
969 .unwrap();
970 match event {
971 ThreadEvent::ToolCall(tool_call) => tool_call,
972 event => {
973 panic!("Unexpected event {event:?}");
974 }
975 }
976}
977
978async fn expect_tool_call_update_fields(
979 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
980) -> acp::ToolCallUpdate {
981 let event = events
982 .next()
983 .await
984 .expect("no tool call authorization event received")
985 .unwrap();
986 match event {
987 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
988 event => {
989 panic!("Unexpected event {event:?}");
990 }
991 }
992}
993
994async fn next_tool_call_authorization(
995 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
996) -> ToolCallAuthorization {
997 loop {
998 let event = events
999 .next()
1000 .await
1001 .expect("no tool call authorization event received")
1002 .unwrap();
1003 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1004 let permission_kinds = tool_call_authorization
1005 .options
1006 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1007 .map(|option| option.kind);
1008 let allow_once = tool_call_authorization
1009 .options
1010 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1011 .map(|option| option.kind);
1012
1013 assert_eq!(
1014 permission_kinds,
1015 Some(acp::PermissionOptionKind::AllowAlways)
1016 );
1017 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1018 return tool_call_authorization;
1019 }
1020 }
1021}
1022
/// `cargo build --release` offers three choices: the whole tool, `cargo build` commands, and
/// a one-time allow.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let inputs = vec!["cargo build --release".to_string()];
    let context = ToolPermissionContext::new(TerminalTool::NAME, inputs);

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1044
/// When the second token is a flag (`ls -la`), the pattern choice covers just `ls`.
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let inputs = vec!["ls -la".to_string()];
    let context = ToolPermissionContext::new(TerminalTool::NAME, inputs);

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1064
/// A single-word command (`whoami`) yields a pattern choice for that command.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let inputs = vec!["whoami".to_string()];
    let context = ToolPermissionContext::new(TerminalTool::NAME, inputs);

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };
    assert_eq!(choices.len(), 3);

    let labels: Vec<&str> = choices
        .iter()
        .map(|entry| entry.allow.name.as_ref())
        .collect();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1084
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing `src/main.rs` is expected to offer both a tool-wide choice and
    // a choice scoped to the `src/` directory.
    let context = ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let allow_labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(allow_labels.contains(&expected));
    }
}
1102
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a URL is expected to offer both a tool-wide choice and a
    // choice scoped to the URL's domain (`docs.rs`).
    let context =
        ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()]);

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let allow_labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(allow_labels.contains(&expected));
    }
}
1120
#[test]
fn test_permission_options_without_pattern() {
    // For `./deploy.sh --production` no command-pattern choice is expected:
    // only the tool-wide and one-off choices should be offered.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 2);

    let allow_labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(allow_labels.contains(&expected));
    }
    // None of the offered labels may mention a command pattern.
    assert!(allow_labels.iter().all(|label| !label.contains("commands")));
}
1142
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization is expected to produce a flat list with
    // exactly one allow-once and one reject-once option.
    let context = ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    );

    let options = match context.build_permission_options() {
        PermissionOptions::Flat(options) => options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };
    assert_eq!(options.len(), 2);

    // True when some option carries both the given id and the given kind.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1163
#[test]
fn test_permission_option_ids_for_terminal() {
    // Checks the option ids (rather than the labels) produced for a terminal
    // command, on both the allow and the deny side of every dropdown choice.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );

    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    assert!(allow_ids.iter().any(|id| id.as_str() == "always_allow:terminal"));
    assert!(allow_ids.iter().any(|id| id.as_str() == "allow"));
    let has_allow_pattern = allow_ids
        .iter()
        .any(|id| id.starts_with("always_allow_pattern:terminal\n"));
    assert!(has_allow_pattern, "Missing allow pattern option");

    assert!(deny_ids.iter().any(|id| id.as_str() == "always_deny:terminal"));
    assert!(deny_ids.iter().any(|id| id.as_str() == "deny"));
    let has_deny_pattern = deny_ids
        .iter()
        .any(|id| id.starts_with("always_deny_pattern:terminal\n"));
    assert!(has_deny_pattern, "Missing deny pattern option");
}
1203
1204#[gpui::test]
1205#[cfg_attr(not(feature = "e2e"), ignore)]
1206async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1207 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1208
1209 // Test concurrent tool calls with different delay times
1210 let events = thread
1211 .update(cx, |thread, cx| {
1212 thread.add_tool(DelayTool, None);
1213 thread.send(
1214 UserMessageId::new(),
1215 [
1216 "Call the delay tool twice in the same message.",
1217 "Once with 100ms. Once with 300ms.",
1218 "When both timers are complete, describe the outputs.",
1219 ],
1220 cx,
1221 )
1222 })
1223 .unwrap()
1224 .collect()
1225 .await;
1226
1227 let stop_reasons = stop_events(events);
1228 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1229
1230 thread.update(cx, |thread, _cx| {
1231 let last_message = thread.last_message().unwrap();
1232 let agent_message = last_message.as_agent_message().unwrap();
1233 let text = agent_message
1234 .content
1235 .iter()
1236 .filter_map(|content| {
1237 if let AgentMessageContent::Text(text) = content {
1238 Some(text.as_str())
1239 } else {
1240 None
1241 }
1242 })
1243 .collect::<String>();
1244
1245 assert!(text.contains("Ding"));
1246 });
1247}
1248
1249#[gpui::test]
1250async fn test_profiles(cx: &mut TestAppContext) {
1251 let ThreadTest {
1252 model, thread, fs, ..
1253 } = setup(cx, TestModel::Fake).await;
1254 let fake_model = model.as_fake();
1255
1256 thread.update(cx, |thread, _cx| {
1257 thread.add_tool(DelayTool, None);
1258 thread.add_tool(EchoTool, None);
1259 thread.add_tool(InfiniteTool, None);
1260 });
1261
1262 // Override profiles and wait for settings to be loaded.
1263 fs.insert_file(
1264 paths::settings_file(),
1265 json!({
1266 "agent": {
1267 "profiles": {
1268 "test-1": {
1269 "name": "Test Profile 1",
1270 "tools": {
1271 EchoTool::NAME: true,
1272 DelayTool::NAME: true,
1273 }
1274 },
1275 "test-2": {
1276 "name": "Test Profile 2",
1277 "tools": {
1278 InfiniteTool::NAME: true,
1279 }
1280 }
1281 }
1282 }
1283 })
1284 .to_string()
1285 .into_bytes(),
1286 )
1287 .await;
1288 cx.run_until_parked();
1289
1290 // Test that test-1 profile (default) has echo and delay tools
1291 thread
1292 .update(cx, |thread, cx| {
1293 thread.set_profile(AgentProfileId("test-1".into()), cx);
1294 thread.send(UserMessageId::new(), ["test"], cx)
1295 })
1296 .unwrap();
1297 cx.run_until_parked();
1298
1299 let mut pending_completions = fake_model.pending_completions();
1300 assert_eq!(pending_completions.len(), 1);
1301 let completion = pending_completions.pop().unwrap();
1302 let tool_names: Vec<String> = completion
1303 .tools
1304 .iter()
1305 .map(|tool| tool.name.clone())
1306 .collect();
1307 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1308 fake_model.end_last_completion_stream();
1309
1310 // Switch to test-2 profile, and verify that it has only the infinite tool.
1311 thread
1312 .update(cx, |thread, cx| {
1313 thread.set_profile(AgentProfileId("test-2".into()), cx);
1314 thread.send(UserMessageId::new(), ["test2"], cx)
1315 })
1316 .unwrap();
1317 cx.run_until_parked();
1318 let mut pending_completions = fake_model.pending_completions();
1319 assert_eq!(pending_completions.len(), 1);
1320 let completion = pending_completions.pop().unwrap();
1321 let tool_names: Vec<String> = completion
1322 .tools
1323 .iter()
1324 .map(|tool| tool.name.clone())
1325 .collect();
1326 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1327}
1328
1329#[gpui::test]
1330async fn test_mcp_tools(cx: &mut TestAppContext) {
1331 let ThreadTest {
1332 model,
1333 thread,
1334 context_server_store,
1335 fs,
1336 ..
1337 } = setup(cx, TestModel::Fake).await;
1338 let fake_model = model.as_fake();
1339
1340 // Override profiles and wait for settings to be loaded.
1341 fs.insert_file(
1342 paths::settings_file(),
1343 json!({
1344 "agent": {
1345 "tool_permissions": { "default": "allow" },
1346 "profiles": {
1347 "test": {
1348 "name": "Test Profile",
1349 "enable_all_context_servers": true,
1350 "tools": {
1351 EchoTool::NAME: true,
1352 }
1353 },
1354 }
1355 }
1356 })
1357 .to_string()
1358 .into_bytes(),
1359 )
1360 .await;
1361 cx.run_until_parked();
1362 thread.update(cx, |thread, cx| {
1363 thread.set_profile(AgentProfileId("test".into()), cx)
1364 });
1365
1366 let mut mcp_tool_calls = setup_context_server(
1367 "test_server",
1368 vec![context_server::types::Tool {
1369 name: "echo".into(),
1370 description: None,
1371 input_schema: serde_json::to_value(EchoTool::input_schema(
1372 LanguageModelToolSchemaFormat::JsonSchema,
1373 ))
1374 .unwrap(),
1375 output_schema: None,
1376 annotations: None,
1377 }],
1378 &context_server_store,
1379 cx,
1380 );
1381
1382 let events = thread.update(cx, |thread, cx| {
1383 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1384 });
1385 cx.run_until_parked();
1386
1387 // Simulate the model calling the MCP tool.
1388 let completion = fake_model.pending_completions().pop().unwrap();
1389 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1390 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1391 LanguageModelToolUse {
1392 id: "tool_1".into(),
1393 name: "echo".into(),
1394 raw_input: json!({"text": "test"}).to_string(),
1395 input: json!({"text": "test"}),
1396 is_input_complete: true,
1397 thought_signature: None,
1398 },
1399 ));
1400 fake_model.end_last_completion_stream();
1401 cx.run_until_parked();
1402
1403 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1404 assert_eq!(tool_call_params.name, "echo");
1405 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1406 tool_call_response
1407 .send(context_server::types::CallToolResponse {
1408 content: vec![context_server::types::ToolResponseContent::Text {
1409 text: "test".into(),
1410 }],
1411 is_error: None,
1412 meta: None,
1413 structured_content: None,
1414 })
1415 .unwrap();
1416 cx.run_until_parked();
1417
1418 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1419 fake_model.send_last_completion_stream_text_chunk("Done!");
1420 fake_model.end_last_completion_stream();
1421 events.collect::<Vec<_>>().await;
1422
1423 // Send again after adding the echo tool, ensuring the name collision is resolved.
1424 let events = thread.update(cx, |thread, cx| {
1425 thread.add_tool(EchoTool, None);
1426 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1427 });
1428 cx.run_until_parked();
1429 let completion = fake_model.pending_completions().pop().unwrap();
1430 assert_eq!(
1431 tool_names_for_completion(&completion),
1432 vec!["echo", "test_server_echo"]
1433 );
1434 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1435 LanguageModelToolUse {
1436 id: "tool_2".into(),
1437 name: "test_server_echo".into(),
1438 raw_input: json!({"text": "mcp"}).to_string(),
1439 input: json!({"text": "mcp"}),
1440 is_input_complete: true,
1441 thought_signature: None,
1442 },
1443 ));
1444 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1445 LanguageModelToolUse {
1446 id: "tool_3".into(),
1447 name: "echo".into(),
1448 raw_input: json!({"text": "native"}).to_string(),
1449 input: json!({"text": "native"}),
1450 is_input_complete: true,
1451 thought_signature: None,
1452 },
1453 ));
1454 fake_model.end_last_completion_stream();
1455 cx.run_until_parked();
1456
1457 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1458 assert_eq!(tool_call_params.name, "echo");
1459 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1460 tool_call_response
1461 .send(context_server::types::CallToolResponse {
1462 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1463 is_error: None,
1464 meta: None,
1465 structured_content: None,
1466 })
1467 .unwrap();
1468 cx.run_until_parked();
1469
1470 // Ensure the tool results were inserted with the correct names.
1471 let completion = fake_model.pending_completions().pop().unwrap();
1472 assert_eq!(
1473 completion.messages.last().unwrap().content,
1474 vec![
1475 MessageContent::ToolResult(LanguageModelToolResult {
1476 tool_use_id: "tool_3".into(),
1477 tool_name: "echo".into(),
1478 is_error: false,
1479 content: "native".into(),
1480 output: Some("native".into()),
1481 },),
1482 MessageContent::ToolResult(LanguageModelToolResult {
1483 tool_use_id: "tool_2".into(),
1484 tool_name: "test_server_echo".into(),
1485 is_error: false,
1486 content: "mcp".into(),
1487 output: Some("mcp".into()),
1488 },),
1489 ]
1490 );
1491 fake_model.end_last_completion_stream();
1492 events.collect::<Vec<_>>().await;
1493}
1494
1495#[gpui::test]
1496async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1497 let ThreadTest {
1498 model,
1499 thread,
1500 context_server_store,
1501 fs,
1502 ..
1503 } = setup(cx, TestModel::Fake).await;
1504 let fake_model = model.as_fake();
1505
1506 // Setup settings to allow MCP tools
1507 fs.insert_file(
1508 paths::settings_file(),
1509 json!({
1510 "agent": {
1511 "always_allow_tool_actions": true,
1512 "profiles": {
1513 "test": {
1514 "name": "Test Profile",
1515 "enable_all_context_servers": true,
1516 "tools": {}
1517 },
1518 }
1519 }
1520 })
1521 .to_string()
1522 .into_bytes(),
1523 )
1524 .await;
1525 cx.run_until_parked();
1526 thread.update(cx, |thread, cx| {
1527 thread.set_profile(AgentProfileId("test".into()), cx)
1528 });
1529
1530 // Setup a context server with a tool
1531 let mut mcp_tool_calls = setup_context_server(
1532 "github_server",
1533 vec![context_server::types::Tool {
1534 name: "issue_read".into(),
1535 description: Some("Read a GitHub issue".into()),
1536 input_schema: json!({
1537 "type": "object",
1538 "properties": {
1539 "issue_url": { "type": "string" }
1540 }
1541 }),
1542 output_schema: None,
1543 annotations: None,
1544 }],
1545 &context_server_store,
1546 cx,
1547 );
1548
1549 // Send a message and have the model call the MCP tool
1550 let events = thread.update(cx, |thread, cx| {
1551 thread
1552 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1553 .unwrap()
1554 });
1555 cx.run_until_parked();
1556
1557 // Verify the MCP tool is available to the model
1558 let completion = fake_model.pending_completions().pop().unwrap();
1559 assert_eq!(
1560 tool_names_for_completion(&completion),
1561 vec!["issue_read"],
1562 "MCP tool should be available"
1563 );
1564
1565 // Simulate the model calling the MCP tool
1566 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1567 LanguageModelToolUse {
1568 id: "tool_1".into(),
1569 name: "issue_read".into(),
1570 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1571 .to_string(),
1572 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1573 is_input_complete: true,
1574 thought_signature: None,
1575 },
1576 ));
1577 fake_model.end_last_completion_stream();
1578 cx.run_until_parked();
1579
1580 // The MCP server receives the tool call and responds with content
1581 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1582 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1583 assert_eq!(tool_call_params.name, "issue_read");
1584 tool_call_response
1585 .send(context_server::types::CallToolResponse {
1586 content: vec![context_server::types::ToolResponseContent::Text {
1587 text: expected_tool_output.into(),
1588 }],
1589 is_error: None,
1590 meta: None,
1591 structured_content: None,
1592 })
1593 .unwrap();
1594 cx.run_until_parked();
1595
1596 // After tool completes, the model continues with a new completion request
1597 // that includes the tool results. We need to respond to this.
1598 let _completion = fake_model.pending_completions().pop().unwrap();
1599 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1600 fake_model
1601 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1602 fake_model.end_last_completion_stream();
1603 events.collect::<Vec<_>>().await;
1604
1605 // Verify the tool result is stored in the thread by checking the markdown output.
1606 // The tool result is in the first assistant message (not the last one, which is
1607 // the model's response after the tool completed).
1608 thread.update(cx, |thread, _cx| {
1609 let markdown = thread.to_markdown();
1610 assert!(
1611 markdown.contains("**Tool Result**: issue_read"),
1612 "Thread should contain tool result header"
1613 );
1614 assert!(
1615 markdown.contains(expected_tool_output),
1616 "Thread should contain tool output: {}",
1617 expected_tool_output
1618 );
1619 });
1620
1621 // Simulate app restart: disconnect the MCP server.
1622 // After restart, the MCP server won't be connected yet when the thread is replayed.
1623 context_server_store.update(cx, |store, cx| {
1624 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1625 });
1626 cx.run_until_parked();
1627
1628 // Replay the thread (this is what happens when loading a saved thread)
1629 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1630
1631 let mut found_tool_call = None;
1632 let mut found_tool_call_update_with_output = None;
1633
1634 while let Some(event) = replay_events.next().await {
1635 let event = event.unwrap();
1636 match &event {
1637 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1638 found_tool_call = Some(tc.clone());
1639 }
1640 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1641 if update.tool_call_id.to_string() == "tool_1" =>
1642 {
1643 if update.fields.raw_output.is_some() {
1644 found_tool_call_update_with_output = Some(update.clone());
1645 }
1646 }
1647 _ => {}
1648 }
1649 }
1650
1651 // The tool call should be found
1652 assert!(
1653 found_tool_call.is_some(),
1654 "Tool call should be emitted during replay"
1655 );
1656
1657 assert!(
1658 found_tool_call_update_with_output.is_some(),
1659 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1660 );
1661
1662 let update = found_tool_call_update_with_output.unwrap();
1663 assert_eq!(
1664 update.fields.raw_output,
1665 Some(expected_tool_output.into()),
1666 "raw_output should contain the saved tool result"
1667 );
1668
1669 // Also verify the status is correct (completed, not failed)
1670 assert_eq!(
1671 update.fields.status,
1672 Some(acp::ToolCallStatus::Completed),
1673 "Tool call status should reflect the original completion status"
1674 );
1675}
1676
1677#[gpui::test]
1678async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1679 let ThreadTest {
1680 model,
1681 thread,
1682 context_server_store,
1683 fs,
1684 ..
1685 } = setup(cx, TestModel::Fake).await;
1686 let fake_model = model.as_fake();
1687
1688 // Set up a profile with all tools enabled
1689 fs.insert_file(
1690 paths::settings_file(),
1691 json!({
1692 "agent": {
1693 "profiles": {
1694 "test": {
1695 "name": "Test Profile",
1696 "enable_all_context_servers": true,
1697 "tools": {
1698 EchoTool::NAME: true,
1699 DelayTool::NAME: true,
1700 WordListTool::NAME: true,
1701 ToolRequiringPermission::NAME: true,
1702 InfiniteTool::NAME: true,
1703 }
1704 },
1705 }
1706 }
1707 })
1708 .to_string()
1709 .into_bytes(),
1710 )
1711 .await;
1712 cx.run_until_parked();
1713
1714 thread.update(cx, |thread, cx| {
1715 thread.set_profile(AgentProfileId("test".into()), cx);
1716 thread.add_tool(EchoTool, None);
1717 thread.add_tool(DelayTool, None);
1718 thread.add_tool(WordListTool, None);
1719 thread.add_tool(ToolRequiringPermission, None);
1720 thread.add_tool(InfiniteTool, None);
1721 });
1722
1723 // Set up multiple context servers with some overlapping tool names
1724 let _server1_calls = setup_context_server(
1725 "xxx",
1726 vec![
1727 context_server::types::Tool {
1728 name: "echo".into(), // Conflicts with native EchoTool
1729 description: None,
1730 input_schema: serde_json::to_value(EchoTool::input_schema(
1731 LanguageModelToolSchemaFormat::JsonSchema,
1732 ))
1733 .unwrap(),
1734 output_schema: None,
1735 annotations: None,
1736 },
1737 context_server::types::Tool {
1738 name: "unique_tool_1".into(),
1739 description: None,
1740 input_schema: json!({"type": "object", "properties": {}}),
1741 output_schema: None,
1742 annotations: None,
1743 },
1744 ],
1745 &context_server_store,
1746 cx,
1747 );
1748
1749 let _server2_calls = setup_context_server(
1750 "yyy",
1751 vec![
1752 context_server::types::Tool {
1753 name: "echo".into(), // Also conflicts with native EchoTool
1754 description: None,
1755 input_schema: serde_json::to_value(EchoTool::input_schema(
1756 LanguageModelToolSchemaFormat::JsonSchema,
1757 ))
1758 .unwrap(),
1759 output_schema: None,
1760 annotations: None,
1761 },
1762 context_server::types::Tool {
1763 name: "unique_tool_2".into(),
1764 description: None,
1765 input_schema: json!({"type": "object", "properties": {}}),
1766 output_schema: None,
1767 annotations: None,
1768 },
1769 context_server::types::Tool {
1770 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1771 description: None,
1772 input_schema: json!({"type": "object", "properties": {}}),
1773 output_schema: None,
1774 annotations: None,
1775 },
1776 context_server::types::Tool {
1777 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1778 description: None,
1779 input_schema: json!({"type": "object", "properties": {}}),
1780 output_schema: None,
1781 annotations: None,
1782 },
1783 ],
1784 &context_server_store,
1785 cx,
1786 );
1787 let _server3_calls = setup_context_server(
1788 "zzz",
1789 vec![
1790 context_server::types::Tool {
1791 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1792 description: None,
1793 input_schema: json!({"type": "object", "properties": {}}),
1794 output_schema: None,
1795 annotations: None,
1796 },
1797 context_server::types::Tool {
1798 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1799 description: None,
1800 input_schema: json!({"type": "object", "properties": {}}),
1801 output_schema: None,
1802 annotations: None,
1803 },
1804 context_server::types::Tool {
1805 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1806 description: None,
1807 input_schema: json!({"type": "object", "properties": {}}),
1808 output_schema: None,
1809 annotations: None,
1810 },
1811 ],
1812 &context_server_store,
1813 cx,
1814 );
1815
1816 // Server with spaces in name - tests snake_case conversion for API compatibility
1817 let _server4_calls = setup_context_server(
1818 "Azure DevOps",
1819 vec![context_server::types::Tool {
1820 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1821 description: None,
1822 input_schema: serde_json::to_value(EchoTool::input_schema(
1823 LanguageModelToolSchemaFormat::JsonSchema,
1824 ))
1825 .unwrap(),
1826 output_schema: None,
1827 annotations: None,
1828 }],
1829 &context_server_store,
1830 cx,
1831 );
1832
1833 thread
1834 .update(cx, |thread, cx| {
1835 thread.send(UserMessageId::new(), ["Go"], cx)
1836 })
1837 .unwrap();
1838 cx.run_until_parked();
1839 let completion = fake_model.pending_completions().pop().unwrap();
1840 assert_eq!(
1841 tool_names_for_completion(&completion),
1842 vec![
1843 "azure_dev_ops_echo",
1844 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1845 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1846 "delay",
1847 "echo",
1848 "infinite",
1849 "tool_requiring_permission",
1850 "unique_tool_1",
1851 "unique_tool_2",
1852 "word_list",
1853 "xxx_echo",
1854 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1855 "yyy_echo",
1856 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1857 ]
1858 );
1859}
1860
1861#[gpui::test]
1862#[cfg_attr(not(feature = "e2e"), ignore)]
1863async fn test_cancellation(cx: &mut TestAppContext) {
1864 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1865
1866 let mut events = thread
1867 .update(cx, |thread, cx| {
1868 thread.add_tool(InfiniteTool, None);
1869 thread.add_tool(EchoTool, None);
1870 thread.send(
1871 UserMessageId::new(),
1872 ["Call the echo tool, then call the infinite tool, then explain their output"],
1873 cx,
1874 )
1875 })
1876 .unwrap();
1877
1878 // Wait until both tools are called.
1879 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1880 let mut echo_id = None;
1881 let mut echo_completed = false;
1882 while let Some(event) = events.next().await {
1883 match event.unwrap() {
1884 ThreadEvent::ToolCall(tool_call) => {
1885 assert_eq!(tool_call.title, expected_tools.remove(0));
1886 if tool_call.title == "Echo" {
1887 echo_id = Some(tool_call.tool_call_id);
1888 }
1889 }
1890 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1891 acp::ToolCallUpdate {
1892 tool_call_id,
1893 fields:
1894 acp::ToolCallUpdateFields {
1895 status: Some(acp::ToolCallStatus::Completed),
1896 ..
1897 },
1898 ..
1899 },
1900 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1901 echo_completed = true;
1902 }
1903 _ => {}
1904 }
1905
1906 if expected_tools.is_empty() && echo_completed {
1907 break;
1908 }
1909 }
1910
1911 // Cancel the current send and ensure that the event stream is closed, even
1912 // if one of the tools is still running.
1913 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1914 let events = events.collect::<Vec<_>>().await;
1915 let last_event = events.last();
1916 assert!(
1917 matches!(
1918 last_event,
1919 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1920 ),
1921 "unexpected event {last_event:?}"
1922 );
1923
1924 // Ensure we can still send a new message after cancellation.
1925 let events = thread
1926 .update(cx, |thread, cx| {
1927 thread.send(
1928 UserMessageId::new(),
1929 ["Testing: reply with 'Hello' then stop."],
1930 cx,
1931 )
1932 })
1933 .unwrap()
1934 .collect::<Vec<_>>()
1935 .await;
1936 thread.update(cx, |thread, _cx| {
1937 let message = thread.last_message().unwrap();
1938 let agent_message = message.as_agent_message().unwrap();
1939 assert_eq!(
1940 agent_message.content,
1941 vec![AgentMessageContent::Text("Hello".to_string())]
1942 );
1943 });
1944 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1945}
1946
1947#[gpui::test]
1948async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1949 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1950 always_allow_tools(cx);
1951 let fake_model = model.as_fake();
1952
1953 let environment = Rc::new(cx.update(|cx| {
1954 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1955 }));
1956 let handle = environment.terminal_handle.clone().unwrap();
1957
1958 let mut events = thread
1959 .update(cx, |thread, cx| {
1960 thread.add_tool(
1961 crate::TerminalTool::new(thread.project().clone(), environment),
1962 None,
1963 );
1964 thread.send(UserMessageId::new(), ["run a command"], cx)
1965 })
1966 .unwrap();
1967
1968 cx.run_until_parked();
1969
1970 // Simulate the model calling the terminal tool
1971 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1972 LanguageModelToolUse {
1973 id: "terminal_tool_1".into(),
1974 name: TerminalTool::NAME.into(),
1975 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1976 input: json!({"command": "sleep 1000", "cd": "."}),
1977 is_input_complete: true,
1978 thought_signature: None,
1979 },
1980 ));
1981 fake_model.end_last_completion_stream();
1982
1983 // Wait for the terminal tool to start running
1984 wait_for_terminal_tool_started(&mut events, cx).await;
1985
1986 // Cancel the thread while the terminal is running
1987 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1988
1989 // Collect remaining events, driving the executor to let cancellation complete
1990 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1991
1992 // Verify the terminal was killed
1993 assert!(
1994 handle.was_killed(),
1995 "expected terminal handle to be killed on cancellation"
1996 );
1997
1998 // Verify we got a cancellation stop event
1999 assert_eq!(
2000 stop_events(remaining_events),
2001 vec![acp::StopReason::Cancelled],
2002 );
2003
2004 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
2005 thread.update(cx, |thread, _cx| {
2006 let message = thread.last_message().unwrap();
2007 let agent_message = message.as_agent_message().unwrap();
2008
2009 let tool_use = agent_message
2010 .content
2011 .iter()
2012 .find_map(|content| match content {
2013 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2014 _ => None,
2015 })
2016 .expect("expected tool use in agent message");
2017
2018 let tool_result = agent_message
2019 .tool_results
2020 .get(&tool_use.id)
2021 .expect("expected tool result");
2022
2023 let result_text = match &tool_result.content {
2024 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2025 _ => panic!("expected text content in tool result"),
2026 };
2027
2028 // "partial output" comes from FakeTerminalHandle's output field
2029 assert!(
2030 result_text.contains("partial output"),
2031 "expected tool result to contain terminal output, got: {result_text}"
2032 );
2033 // Match the actual format from process_content in terminal_tool.rs
2034 assert!(
2035 result_text.contains("The user stopped this command"),
2036 "expected tool result to indicate user stopped, got: {result_text}"
2037 );
2038 });
2039
2040 // Verify we can send a new message after cancellation
2041 verify_thread_recovery(&thread, &fake_model, cx).await;
2042}
2043
2044#[gpui::test]
2045async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2046 // This test verifies that tools which properly handle cancellation via
2047 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2048 // to cancellation and report that they were cancelled.
2049 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2050 always_allow_tools(cx);
2051 let fake_model = model.as_fake();
2052
2053 let (tool, was_cancelled) = CancellationAwareTool::new();
2054
2055 let mut events = thread
2056 .update(cx, |thread, cx| {
2057 thread.add_tool(tool, None);
2058 thread.send(
2059 UserMessageId::new(),
2060 ["call the cancellation aware tool"],
2061 cx,
2062 )
2063 })
2064 .unwrap();
2065
2066 cx.run_until_parked();
2067
2068 // Simulate the model calling the cancellation-aware tool
2069 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2070 LanguageModelToolUse {
2071 id: "cancellation_aware_1".into(),
2072 name: "cancellation_aware".into(),
2073 raw_input: r#"{}"#.into(),
2074 input: json!({}),
2075 is_input_complete: true,
2076 thought_signature: None,
2077 },
2078 ));
2079 fake_model.end_last_completion_stream();
2080
2081 cx.run_until_parked();
2082
2083 // Wait for the tool call to be reported
2084 let mut tool_started = false;
2085 let deadline = cx.executor().num_cpus() * 100;
2086 for _ in 0..deadline {
2087 cx.run_until_parked();
2088
2089 while let Some(Some(event)) = events.next().now_or_never() {
2090 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2091 if tool_call.title == "Cancellation Aware Tool" {
2092 tool_started = true;
2093 break;
2094 }
2095 }
2096 }
2097
2098 if tool_started {
2099 break;
2100 }
2101
2102 cx.background_executor
2103 .timer(Duration::from_millis(10))
2104 .await;
2105 }
2106 assert!(tool_started, "expected cancellation aware tool to start");
2107
2108 // Cancel the thread and wait for it to complete
2109 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2110
2111 // The cancel task should complete promptly because the tool handles cancellation
2112 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2113 futures::select! {
2114 _ = cancel_task.fuse() => {}
2115 _ = timeout.fuse() => {
2116 panic!("cancel task timed out - tool did not respond to cancellation");
2117 }
2118 }
2119
2120 // Verify the tool detected cancellation via its flag
2121 assert!(
2122 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2123 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2124 );
2125
2126 // Collect remaining events
2127 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2128
2129 // Verify we got a cancellation stop event
2130 assert_eq!(
2131 stop_events(remaining_events),
2132 vec![acp::StopReason::Cancelled],
2133 );
2134
2135 // Verify we can send a new message after cancellation
2136 verify_thread_recovery(&thread, &fake_model, cx).await;
2137}
2138
2139/// Helper to verify thread can recover after cancellation by sending a simple message.
2140async fn verify_thread_recovery(
2141 thread: &Entity<Thread>,
2142 fake_model: &FakeLanguageModel,
2143 cx: &mut TestAppContext,
2144) {
2145 let events = thread
2146 .update(cx, |thread, cx| {
2147 thread.send(
2148 UserMessageId::new(),
2149 ["Testing: reply with 'Hello' then stop."],
2150 cx,
2151 )
2152 })
2153 .unwrap();
2154 cx.run_until_parked();
2155 fake_model.send_last_completion_stream_text_chunk("Hello");
2156 fake_model
2157 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2158 fake_model.end_last_completion_stream();
2159
2160 let events = events.collect::<Vec<_>>().await;
2161 thread.update(cx, |thread, _cx| {
2162 let message = thread.last_message().unwrap();
2163 let agent_message = message.as_agent_message().unwrap();
2164 assert_eq!(
2165 agent_message.content,
2166 vec![AgentMessageContent::Text("Hello".to_string())]
2167 );
2168 });
2169 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2170}
2171
2172/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2173async fn wait_for_terminal_tool_started(
2174 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2175 cx: &mut TestAppContext,
2176) {
2177 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2178 for _ in 0..deadline {
2179 cx.run_until_parked();
2180
2181 while let Some(Some(event)) = events.next().now_or_never() {
2182 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2183 update,
2184 ))) = &event
2185 {
2186 if update.fields.content.as_ref().is_some_and(|content| {
2187 content
2188 .iter()
2189 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2190 }) {
2191 return;
2192 }
2193 }
2194 }
2195
2196 cx.background_executor
2197 .timer(Duration::from_millis(10))
2198 .await;
2199 }
2200 panic!("terminal tool did not start within the expected time");
2201}
2202
2203/// Collects events until a Stop event is received, driving the executor to completion.
2204async fn collect_events_until_stop(
2205 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2206 cx: &mut TestAppContext,
2207) -> Vec<Result<ThreadEvent>> {
2208 let mut collected = Vec::new();
2209 let deadline = cx.executor().num_cpus() * 200;
2210
2211 for _ in 0..deadline {
2212 cx.executor().advance_clock(Duration::from_millis(10));
2213 cx.run_until_parked();
2214
2215 while let Some(Some(event)) = events.next().now_or_never() {
2216 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2217 collected.push(event);
2218 if is_stop {
2219 return collected;
2220 }
2221 }
2222 }
2223 panic!(
2224 "did not receive Stop event within the expected time; collected {} events",
2225 collected.len()
2226 );
2227}
2228
2229#[gpui::test]
2230async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2231 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2232 always_allow_tools(cx);
2233 let fake_model = model.as_fake();
2234
2235 let environment = Rc::new(cx.update(|cx| {
2236 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2237 }));
2238 let handle = environment.terminal_handle.clone().unwrap();
2239
2240 let message_id = UserMessageId::new();
2241 let mut events = thread
2242 .update(cx, |thread, cx| {
2243 thread.add_tool(
2244 crate::TerminalTool::new(thread.project().clone(), environment),
2245 None,
2246 );
2247 thread.send(message_id.clone(), ["run a command"], cx)
2248 })
2249 .unwrap();
2250
2251 cx.run_until_parked();
2252
2253 // Simulate the model calling the terminal tool
2254 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2255 LanguageModelToolUse {
2256 id: "terminal_tool_1".into(),
2257 name: TerminalTool::NAME.into(),
2258 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2259 input: json!({"command": "sleep 1000", "cd": "."}),
2260 is_input_complete: true,
2261 thought_signature: None,
2262 },
2263 ));
2264 fake_model.end_last_completion_stream();
2265
2266 // Wait for the terminal tool to start running
2267 wait_for_terminal_tool_started(&mut events, cx).await;
2268
2269 // Truncate the thread while the terminal is running
2270 thread
2271 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2272 .unwrap();
2273
2274 // Drive the executor to let cancellation complete
2275 let _ = collect_events_until_stop(&mut events, cx).await;
2276
2277 // Verify the terminal was killed
2278 assert!(
2279 handle.was_killed(),
2280 "expected terminal handle to be killed on truncate"
2281 );
2282
2283 // Verify the thread is empty after truncation
2284 thread.update(cx, |thread, _cx| {
2285 assert_eq!(
2286 thread.to_markdown(),
2287 "",
2288 "expected thread to be empty after truncating the only message"
2289 );
2290 });
2291
2292 // Verify we can send a new message after truncation
2293 verify_thread_recovery(&thread, &fake_model, cx).await;
2294}
2295
2296#[gpui::test]
2297async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2298 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2299 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2300 always_allow_tools(cx);
2301 let fake_model = model.as_fake();
2302
2303 let environment = Rc::new(MultiTerminalEnvironment::new());
2304
2305 let mut events = thread
2306 .update(cx, |thread, cx| {
2307 thread.add_tool(
2308 crate::TerminalTool::new(thread.project().clone(), environment.clone()),
2309 None,
2310 );
2311 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2312 })
2313 .unwrap();
2314
2315 cx.run_until_parked();
2316
2317 // Simulate the model calling two terminal tools
2318 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2319 LanguageModelToolUse {
2320 id: "terminal_tool_1".into(),
2321 name: TerminalTool::NAME.into(),
2322 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2323 input: json!({"command": "sleep 1000", "cd": "."}),
2324 is_input_complete: true,
2325 thought_signature: None,
2326 },
2327 ));
2328 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2329 LanguageModelToolUse {
2330 id: "terminal_tool_2".into(),
2331 name: TerminalTool::NAME.into(),
2332 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2333 input: json!({"command": "sleep 2000", "cd": "."}),
2334 is_input_complete: true,
2335 thought_signature: None,
2336 },
2337 ));
2338 fake_model.end_last_completion_stream();
2339
2340 // Wait for both terminal tools to start by counting terminal content updates
2341 let mut terminals_started = 0;
2342 let deadline = cx.executor().num_cpus() * 100;
2343 for _ in 0..deadline {
2344 cx.run_until_parked();
2345
2346 while let Some(Some(event)) = events.next().now_or_never() {
2347 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2348 update,
2349 ))) = &event
2350 {
2351 if update.fields.content.as_ref().is_some_and(|content| {
2352 content
2353 .iter()
2354 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2355 }) {
2356 terminals_started += 1;
2357 if terminals_started >= 2 {
2358 break;
2359 }
2360 }
2361 }
2362 }
2363 if terminals_started >= 2 {
2364 break;
2365 }
2366
2367 cx.background_executor
2368 .timer(Duration::from_millis(10))
2369 .await;
2370 }
2371 assert!(
2372 terminals_started >= 2,
2373 "expected 2 terminal tools to start, got {terminals_started}"
2374 );
2375
2376 // Cancel the thread while both terminals are running
2377 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2378
2379 // Collect remaining events
2380 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2381
2382 // Verify both terminal handles were killed
2383 let handles = environment.handles();
2384 assert_eq!(
2385 handles.len(),
2386 2,
2387 "expected 2 terminal handles to be created"
2388 );
2389 assert!(
2390 handles[0].was_killed(),
2391 "expected first terminal handle to be killed on cancellation"
2392 );
2393 assert!(
2394 handles[1].was_killed(),
2395 "expected second terminal handle to be killed on cancellation"
2396 );
2397
2398 // Verify we got a cancellation stop event
2399 assert_eq!(
2400 stop_events(remaining_events),
2401 vec![acp::StopReason::Cancelled],
2402 );
2403}
2404
2405#[gpui::test]
2406async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2407 // Tests that clicking the stop button on the terminal card (as opposed to the main
2408 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2409 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2410 always_allow_tools(cx);
2411 let fake_model = model.as_fake();
2412
2413 let environment = Rc::new(cx.update(|cx| {
2414 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2415 }));
2416 let handle = environment.terminal_handle.clone().unwrap();
2417
2418 let mut events = thread
2419 .update(cx, |thread, cx| {
2420 thread.add_tool(
2421 crate::TerminalTool::new(thread.project().clone(), environment),
2422 None,
2423 );
2424 thread.send(UserMessageId::new(), ["run a command"], cx)
2425 })
2426 .unwrap();
2427
2428 cx.run_until_parked();
2429
2430 // Simulate the model calling the terminal tool
2431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2432 LanguageModelToolUse {
2433 id: "terminal_tool_1".into(),
2434 name: TerminalTool::NAME.into(),
2435 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2436 input: json!({"command": "sleep 1000", "cd": "."}),
2437 is_input_complete: true,
2438 thought_signature: None,
2439 },
2440 ));
2441 fake_model.end_last_completion_stream();
2442
2443 // Wait for the terminal tool to start running
2444 wait_for_terminal_tool_started(&mut events, cx).await;
2445
2446 // Simulate user clicking stop on the terminal card itself.
2447 // This sets the flag and signals exit (simulating what the real UI would do).
2448 handle.set_stopped_by_user(true);
2449 handle.killed.store(true, Ordering::SeqCst);
2450 handle.signal_exit();
2451
2452 // Wait for the tool to complete
2453 cx.run_until_parked();
2454
2455 // The thread continues after tool completion - simulate the model ending its turn
2456 fake_model
2457 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2458 fake_model.end_last_completion_stream();
2459
2460 // Collect remaining events
2461 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2462
2463 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2464 assert_eq!(
2465 stop_events(remaining_events),
2466 vec![acp::StopReason::EndTurn],
2467 );
2468
2469 // Verify the tool result indicates user stopped
2470 thread.update(cx, |thread, _cx| {
2471 let message = thread.last_message().unwrap();
2472 let agent_message = message.as_agent_message().unwrap();
2473
2474 let tool_use = agent_message
2475 .content
2476 .iter()
2477 .find_map(|content| match content {
2478 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2479 _ => None,
2480 })
2481 .expect("expected tool use in agent message");
2482
2483 let tool_result = agent_message
2484 .tool_results
2485 .get(&tool_use.id)
2486 .expect("expected tool result");
2487
2488 let result_text = match &tool_result.content {
2489 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2490 _ => panic!("expected text content in tool result"),
2491 };
2492
2493 assert!(
2494 result_text.contains("The user stopped this command"),
2495 "expected tool result to indicate user stopped, got: {result_text}"
2496 );
2497 });
2498}
2499
2500#[gpui::test]
2501async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2502 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2503 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2504 always_allow_tools(cx);
2505 let fake_model = model.as_fake();
2506
2507 let environment = Rc::new(cx.update(|cx| {
2508 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2509 }));
2510 let handle = environment.terminal_handle.clone().unwrap();
2511
2512 let mut events = thread
2513 .update(cx, |thread, cx| {
2514 thread.add_tool(
2515 crate::TerminalTool::new(thread.project().clone(), environment),
2516 None,
2517 );
2518 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2519 })
2520 .unwrap();
2521
2522 cx.run_until_parked();
2523
2524 // Simulate the model calling the terminal tool with a short timeout
2525 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2526 LanguageModelToolUse {
2527 id: "terminal_tool_1".into(),
2528 name: TerminalTool::NAME.into(),
2529 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2530 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2531 is_input_complete: true,
2532 thought_signature: None,
2533 },
2534 ));
2535 fake_model.end_last_completion_stream();
2536
2537 // Wait for the terminal tool to start running
2538 wait_for_terminal_tool_started(&mut events, cx).await;
2539
2540 // Advance clock past the timeout
2541 cx.executor().advance_clock(Duration::from_millis(200));
2542 cx.run_until_parked();
2543
2544 // The thread continues after tool completion - simulate the model ending its turn
2545 fake_model
2546 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2547 fake_model.end_last_completion_stream();
2548
2549 // Collect remaining events
2550 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2551
2552 // Verify the terminal was killed due to timeout
2553 assert!(
2554 handle.was_killed(),
2555 "expected terminal handle to be killed on timeout"
2556 );
2557
2558 // Verify we got an EndTurn (the tool completed, just with timeout)
2559 assert_eq!(
2560 stop_events(remaining_events),
2561 vec![acp::StopReason::EndTurn],
2562 );
2563
2564 // Verify the tool result indicates timeout, not user stopped
2565 thread.update(cx, |thread, _cx| {
2566 let message = thread.last_message().unwrap();
2567 let agent_message = message.as_agent_message().unwrap();
2568
2569 let tool_use = agent_message
2570 .content
2571 .iter()
2572 .find_map(|content| match content {
2573 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2574 _ => None,
2575 })
2576 .expect("expected tool use in agent message");
2577
2578 let tool_result = agent_message
2579 .tool_results
2580 .get(&tool_use.id)
2581 .expect("expected tool result");
2582
2583 let result_text = match &tool_result.content {
2584 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2585 _ => panic!("expected text content in tool result"),
2586 };
2587
2588 assert!(
2589 result_text.contains("timed out"),
2590 "expected tool result to indicate timeout, got: {result_text}"
2591 );
2592 assert!(
2593 !result_text.contains("The user stopped"),
2594 "tool result should not mention user stopped when it timed out, got: {result_text}"
2595 );
2596 });
2597}
2598
2599#[gpui::test]
2600async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2601 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2602 let fake_model = model.as_fake();
2603
2604 let events_1 = thread
2605 .update(cx, |thread, cx| {
2606 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2607 })
2608 .unwrap();
2609 cx.run_until_parked();
2610 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2611 cx.run_until_parked();
2612
2613 let events_2 = thread
2614 .update(cx, |thread, cx| {
2615 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2616 })
2617 .unwrap();
2618 cx.run_until_parked();
2619 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2620 fake_model
2621 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2622 fake_model.end_last_completion_stream();
2623
2624 let events_1 = events_1.collect::<Vec<_>>().await;
2625 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2626 let events_2 = events_2.collect::<Vec<_>>().await;
2627 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2628}
2629
2630#[gpui::test]
2631async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2632 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2633 let fake_model = model.as_fake();
2634
2635 let events_1 = thread
2636 .update(cx, |thread, cx| {
2637 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2638 })
2639 .unwrap();
2640 cx.run_until_parked();
2641 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2642 fake_model
2643 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2644 fake_model.end_last_completion_stream();
2645 let events_1 = events_1.collect::<Vec<_>>().await;
2646
2647 let events_2 = thread
2648 .update(cx, |thread, cx| {
2649 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2650 })
2651 .unwrap();
2652 cx.run_until_parked();
2653 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2654 fake_model
2655 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2656 fake_model.end_last_completion_stream();
2657 let events_2 = events_2.collect::<Vec<_>>().await;
2658
2659 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2660 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2661}
2662
2663#[gpui::test]
2664async fn test_refusal(cx: &mut TestAppContext) {
2665 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2666 let fake_model = model.as_fake();
2667
2668 let events = thread
2669 .update(cx, |thread, cx| {
2670 thread.send(UserMessageId::new(), ["Hello"], cx)
2671 })
2672 .unwrap();
2673 cx.run_until_parked();
2674 thread.read_with(cx, |thread, _| {
2675 assert_eq!(
2676 thread.to_markdown(),
2677 indoc! {"
2678 ## User
2679
2680 Hello
2681 "}
2682 );
2683 });
2684
2685 fake_model.send_last_completion_stream_text_chunk("Hey!");
2686 cx.run_until_parked();
2687 thread.read_with(cx, |thread, _| {
2688 assert_eq!(
2689 thread.to_markdown(),
2690 indoc! {"
2691 ## User
2692
2693 Hello
2694
2695 ## Assistant
2696
2697 Hey!
2698 "}
2699 );
2700 });
2701
2702 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2703 fake_model
2704 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2705 let events = events.collect::<Vec<_>>().await;
2706 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2707 thread.read_with(cx, |thread, _| {
2708 assert_eq!(thread.to_markdown(), "");
2709 });
2710}
2711
2712#[gpui::test]
2713async fn test_truncate_first_message(cx: &mut TestAppContext) {
2714 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2715 let fake_model = model.as_fake();
2716
2717 let message_id = UserMessageId::new();
2718 thread
2719 .update(cx, |thread, cx| {
2720 thread.send(message_id.clone(), ["Hello"], cx)
2721 })
2722 .unwrap();
2723 cx.run_until_parked();
2724 thread.read_with(cx, |thread, _| {
2725 assert_eq!(
2726 thread.to_markdown(),
2727 indoc! {"
2728 ## User
2729
2730 Hello
2731 "}
2732 );
2733 assert_eq!(thread.latest_token_usage(), None);
2734 });
2735
2736 fake_model.send_last_completion_stream_text_chunk("Hey!");
2737 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2738 language_model::TokenUsage {
2739 input_tokens: 32_000,
2740 output_tokens: 16_000,
2741 cache_creation_input_tokens: 0,
2742 cache_read_input_tokens: 0,
2743 },
2744 ));
2745 cx.run_until_parked();
2746 thread.read_with(cx, |thread, _| {
2747 assert_eq!(
2748 thread.to_markdown(),
2749 indoc! {"
2750 ## User
2751
2752 Hello
2753
2754 ## Assistant
2755
2756 Hey!
2757 "}
2758 );
2759 assert_eq!(
2760 thread.latest_token_usage(),
2761 Some(acp_thread::TokenUsage {
2762 used_tokens: 32_000 + 16_000,
2763 max_tokens: 1_000_000,
2764 max_output_tokens: None,
2765 input_tokens: 32_000,
2766 output_tokens: 16_000,
2767 })
2768 );
2769 });
2770
2771 thread
2772 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2773 .unwrap();
2774 cx.run_until_parked();
2775 thread.read_with(cx, |thread, _| {
2776 assert_eq!(thread.to_markdown(), "");
2777 assert_eq!(thread.latest_token_usage(), None);
2778 });
2779
2780 // Ensure we can still send a new message after truncation.
2781 thread
2782 .update(cx, |thread, cx| {
2783 thread.send(UserMessageId::new(), ["Hi"], cx)
2784 })
2785 .unwrap();
2786 thread.update(cx, |thread, _cx| {
2787 assert_eq!(
2788 thread.to_markdown(),
2789 indoc! {"
2790 ## User
2791
2792 Hi
2793 "}
2794 );
2795 });
2796 cx.run_until_parked();
2797 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2798 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2799 language_model::TokenUsage {
2800 input_tokens: 40_000,
2801 output_tokens: 20_000,
2802 cache_creation_input_tokens: 0,
2803 cache_read_input_tokens: 0,
2804 },
2805 ));
2806 cx.run_until_parked();
2807 thread.read_with(cx, |thread, _| {
2808 assert_eq!(
2809 thread.to_markdown(),
2810 indoc! {"
2811 ## User
2812
2813 Hi
2814
2815 ## Assistant
2816
2817 Ahoy!
2818 "}
2819 );
2820
2821 assert_eq!(
2822 thread.latest_token_usage(),
2823 Some(acp_thread::TokenUsage {
2824 used_tokens: 40_000 + 20_000,
2825 max_tokens: 1_000_000,
2826 max_output_tokens: None,
2827 input_tokens: 40_000,
2828 output_tokens: 20_000,
2829 })
2830 );
2831 });
2832}
2833
2834#[gpui::test]
2835async fn test_truncate_second_message(cx: &mut TestAppContext) {
2836 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2837 let fake_model = model.as_fake();
2838
2839 thread
2840 .update(cx, |thread, cx| {
2841 thread.send(UserMessageId::new(), ["Message 1"], cx)
2842 })
2843 .unwrap();
2844 cx.run_until_parked();
2845 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2846 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2847 language_model::TokenUsage {
2848 input_tokens: 32_000,
2849 output_tokens: 16_000,
2850 cache_creation_input_tokens: 0,
2851 cache_read_input_tokens: 0,
2852 },
2853 ));
2854 fake_model.end_last_completion_stream();
2855 cx.run_until_parked();
2856
2857 let assert_first_message_state = |cx: &mut TestAppContext| {
2858 thread.clone().read_with(cx, |thread, _| {
2859 assert_eq!(
2860 thread.to_markdown(),
2861 indoc! {"
2862 ## User
2863
2864 Message 1
2865
2866 ## Assistant
2867
2868 Message 1 response
2869 "}
2870 );
2871
2872 assert_eq!(
2873 thread.latest_token_usage(),
2874 Some(acp_thread::TokenUsage {
2875 used_tokens: 32_000 + 16_000,
2876 max_tokens: 1_000_000,
2877 max_output_tokens: None,
2878 input_tokens: 32_000,
2879 output_tokens: 16_000,
2880 })
2881 );
2882 });
2883 };
2884
2885 assert_first_message_state(cx);
2886
2887 let second_message_id = UserMessageId::new();
2888 thread
2889 .update(cx, |thread, cx| {
2890 thread.send(second_message_id.clone(), ["Message 2"], cx)
2891 })
2892 .unwrap();
2893 cx.run_until_parked();
2894
2895 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2896 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2897 language_model::TokenUsage {
2898 input_tokens: 40_000,
2899 output_tokens: 20_000,
2900 cache_creation_input_tokens: 0,
2901 cache_read_input_tokens: 0,
2902 },
2903 ));
2904 fake_model.end_last_completion_stream();
2905 cx.run_until_parked();
2906
2907 thread.read_with(cx, |thread, _| {
2908 assert_eq!(
2909 thread.to_markdown(),
2910 indoc! {"
2911 ## User
2912
2913 Message 1
2914
2915 ## Assistant
2916
2917 Message 1 response
2918
2919 ## User
2920
2921 Message 2
2922
2923 ## Assistant
2924
2925 Message 2 response
2926 "}
2927 );
2928
2929 assert_eq!(
2930 thread.latest_token_usage(),
2931 Some(acp_thread::TokenUsage {
2932 used_tokens: 40_000 + 20_000,
2933 max_tokens: 1_000_000,
2934 max_output_tokens: None,
2935 input_tokens: 40_000,
2936 output_tokens: 20_000,
2937 })
2938 );
2939 });
2940
2941 thread
2942 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2943 .unwrap();
2944 cx.run_until_parked();
2945
2946 assert_first_message_state(cx);
2947}
2948
2949#[gpui::test]
2950async fn test_title_generation(cx: &mut TestAppContext) {
2951 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2952 let fake_model = model.as_fake();
2953
2954 let summary_model = Arc::new(FakeLanguageModel::default());
2955 thread.update(cx, |thread, cx| {
2956 thread.set_summarization_model(Some(summary_model.clone()), cx)
2957 });
2958
2959 let send = thread
2960 .update(cx, |thread, cx| {
2961 thread.send(UserMessageId::new(), ["Hello"], cx)
2962 })
2963 .unwrap();
2964 cx.run_until_parked();
2965
2966 fake_model.send_last_completion_stream_text_chunk("Hey!");
2967 fake_model.end_last_completion_stream();
2968 cx.run_until_parked();
2969 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2970
2971 // Ensure the summary model has been invoked to generate a title.
2972 summary_model.send_last_completion_stream_text_chunk("Hello ");
2973 summary_model.send_last_completion_stream_text_chunk("world\nG");
2974 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2975 summary_model.end_last_completion_stream();
2976 send.collect::<Vec<_>>().await;
2977 cx.run_until_parked();
2978 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2979
2980 // Send another message, ensuring no title is generated this time.
2981 let send = thread
2982 .update(cx, |thread, cx| {
2983 thread.send(UserMessageId::new(), ["Hello again"], cx)
2984 })
2985 .unwrap();
2986 cx.run_until_parked();
2987 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2988 fake_model.end_last_completion_stream();
2989 cx.run_until_parked();
2990 assert_eq!(summary_model.pending_completions(), Vec::new());
2991 send.collect::<Vec<_>>().await;
2992 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2993}
2994
2995#[gpui::test]
2996async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2997 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2998 let fake_model = model.as_fake();
2999
3000 let _events = thread
3001 .update(cx, |thread, cx| {
3002 thread.add_tool(ToolRequiringPermission, None);
3003 thread.add_tool(EchoTool, None);
3004 thread.send(UserMessageId::new(), ["Hey!"], cx)
3005 })
3006 .unwrap();
3007 cx.run_until_parked();
3008
3009 let permission_tool_use = LanguageModelToolUse {
3010 id: "tool_id_1".into(),
3011 name: ToolRequiringPermission::NAME.into(),
3012 raw_input: "{}".into(),
3013 input: json!({}),
3014 is_input_complete: true,
3015 thought_signature: None,
3016 };
3017 let echo_tool_use = LanguageModelToolUse {
3018 id: "tool_id_2".into(),
3019 name: EchoTool::NAME.into(),
3020 raw_input: json!({"text": "test"}).to_string(),
3021 input: json!({"text": "test"}),
3022 is_input_complete: true,
3023 thought_signature: None,
3024 };
3025 fake_model.send_last_completion_stream_text_chunk("Hi!");
3026 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3027 permission_tool_use,
3028 ));
3029 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3030 echo_tool_use.clone(),
3031 ));
3032 fake_model.end_last_completion_stream();
3033 cx.run_until_parked();
3034
3035 // Ensure pending tools are skipped when building a request.
3036 let request = thread
3037 .read_with(cx, |thread, cx| {
3038 thread.build_completion_request(CompletionIntent::EditFile, cx)
3039 })
3040 .unwrap();
3041 assert_eq!(
3042 request.messages[1..],
3043 vec![
3044 LanguageModelRequestMessage {
3045 role: Role::User,
3046 content: vec!["Hey!".into()],
3047 cache: true,
3048 reasoning_details: None,
3049 },
3050 LanguageModelRequestMessage {
3051 role: Role::Assistant,
3052 content: vec![
3053 MessageContent::Text("Hi!".into()),
3054 MessageContent::ToolUse(echo_tool_use.clone())
3055 ],
3056 cache: false,
3057 reasoning_details: None,
3058 },
3059 LanguageModelRequestMessage {
3060 role: Role::User,
3061 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3062 tool_use_id: echo_tool_use.id.clone(),
3063 tool_name: echo_tool_use.name,
3064 is_error: false,
3065 content: "test".into(),
3066 output: Some("test".into())
3067 })],
3068 cache: false,
3069 reasoning_details: None,
3070 },
3071 ],
3072 );
3073}
3074
3075#[gpui::test]
3076async fn test_agent_connection(cx: &mut TestAppContext) {
3077 cx.update(settings::init);
3078 let templates = Templates::new();
3079
3080 // Initialize language model system with test provider
3081 cx.update(|cx| {
3082 gpui_tokio::init(cx);
3083
3084 let http_client = FakeHttpClient::with_404_response();
3085 let clock = Arc::new(clock::FakeSystemClock::new());
3086 let client = Client::new(clock, http_client, cx);
3087 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3088 language_model::init(client.clone(), cx);
3089 language_models::init(user_store, client.clone(), cx);
3090 LanguageModelRegistry::test(cx);
3091 });
3092 cx.executor().forbid_parking();
3093
3094 // Create a project for new_thread
3095 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3096 fake_fs.insert_tree(path!("/test"), json!({})).await;
3097 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3098 let cwd = Path::new("/test");
3099 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3100
3101 // Create agent and connection
3102 let agent = NativeAgent::new(
3103 project.clone(),
3104 thread_store,
3105 templates.clone(),
3106 None,
3107 fake_fs.clone(),
3108 &mut cx.to_async(),
3109 )
3110 .await
3111 .unwrap();
3112 let connection = NativeAgentConnection(agent.clone());
3113
3114 // Create a thread using new_thread
3115 let connection_rc = Rc::new(connection.clone());
3116 let acp_thread = cx
3117 .update(|cx| connection_rc.new_session(project, cwd, cx))
3118 .await
3119 .expect("new_thread should succeed");
3120
3121 // Get the session_id from the AcpThread
3122 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3123
3124 // Test model_selector returns Some
3125 let selector_opt = connection.model_selector(&session_id);
3126 assert!(
3127 selector_opt.is_some(),
3128 "agent should always support ModelSelector"
3129 );
3130 let selector = selector_opt.unwrap();
3131
3132 // Test list_models
3133 let listed_models = cx
3134 .update(|cx| selector.list_models(cx))
3135 .await
3136 .expect("list_models should succeed");
3137 let AgentModelList::Grouped(listed_models) = listed_models else {
3138 panic!("Unexpected model list type");
3139 };
3140 assert!(!listed_models.is_empty(), "should have at least one model");
3141 assert_eq!(
3142 listed_models[&AgentModelGroupName("Fake".into())][0]
3143 .id
3144 .0
3145 .as_ref(),
3146 "fake/fake"
3147 );
3148
3149 // Test selected_model returns the default
3150 let model = cx
3151 .update(|cx| selector.selected_model(cx))
3152 .await
3153 .expect("selected_model should succeed");
3154 let model = cx
3155 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3156 .unwrap();
3157 let model = model.as_fake();
3158 assert_eq!(model.id().0, "fake", "should return default model");
3159
3160 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3161 cx.run_until_parked();
3162 model.send_last_completion_stream_text_chunk("def");
3163 cx.run_until_parked();
3164 acp_thread.read_with(cx, |thread, cx| {
3165 assert_eq!(
3166 thread.to_markdown(cx),
3167 indoc! {"
3168 ## User
3169
3170 abc
3171
3172 ## Assistant
3173
3174 def
3175
3176 "}
3177 )
3178 });
3179
3180 // Test cancel
3181 cx.update(|cx| connection.cancel(&session_id, cx));
3182 request.await.expect("prompt should fail gracefully");
3183
3184 // Explicitly close the session and drop the ACP thread.
3185 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3186 .await
3187 .unwrap();
3188 drop(acp_thread);
3189 let result = cx
3190 .update(|cx| {
3191 connection.prompt(
3192 Some(acp_thread::UserMessageId::new()),
3193 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3194 cx,
3195 )
3196 })
3197 .await;
3198 assert_eq!(
3199 result.as_ref().unwrap_err().to_string(),
3200 "Session not found",
3201 "unexpected result: {:?}",
3202 result
3203 );
3204}
3205
3206#[gpui::test]
3207async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3208 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3209 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool, None));
3210 let fake_model = model.as_fake();
3211
3212 let mut events = thread
3213 .update(cx, |thread, cx| {
3214 thread.send(UserMessageId::new(), ["Echo something"], cx)
3215 })
3216 .unwrap();
3217 cx.run_until_parked();
3218
3219 // Simulate streaming partial input.
3220 let input = json!({});
3221 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3222 LanguageModelToolUse {
3223 id: "1".into(),
3224 name: EchoTool::NAME.into(),
3225 raw_input: input.to_string(),
3226 input,
3227 is_input_complete: false,
3228 thought_signature: None,
3229 },
3230 ));
3231
3232 // Input streaming completed
3233 let input = json!({ "text": "Hello!" });
3234 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3235 LanguageModelToolUse {
3236 id: "1".into(),
3237 name: "echo".into(),
3238 raw_input: input.to_string(),
3239 input,
3240 is_input_complete: true,
3241 thought_signature: None,
3242 },
3243 ));
3244 fake_model.end_last_completion_stream();
3245 cx.run_until_parked();
3246
3247 let tool_call = expect_tool_call(&mut events).await;
3248 assert_eq!(
3249 tool_call,
3250 acp::ToolCall::new("1", "Echo")
3251 .raw_input(json!({}))
3252 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3253 );
3254 let update = expect_tool_call_update_fields(&mut events).await;
3255 assert_eq!(
3256 update,
3257 acp::ToolCallUpdate::new(
3258 "1",
3259 acp::ToolCallUpdateFields::new()
3260 .title("Echo")
3261 .kind(acp::ToolKind::Other)
3262 .raw_input(json!({ "text": "Hello!"}))
3263 )
3264 );
3265 let update = expect_tool_call_update_fields(&mut events).await;
3266 assert_eq!(
3267 update,
3268 acp::ToolCallUpdate::new(
3269 "1",
3270 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3271 )
3272 );
3273 let update = expect_tool_call_update_fields(&mut events).await;
3274 assert_eq!(
3275 update,
3276 acp::ToolCallUpdate::new(
3277 "1",
3278 acp::ToolCallUpdateFields::new()
3279 .status(acp::ToolCallStatus::Completed)
3280 .raw_output("Hello!")
3281 )
3282 );
3283}
3284
3285#[gpui::test]
3286async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3287 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3288 let fake_model = model.as_fake();
3289
3290 let mut events = thread
3291 .update(cx, |thread, cx| {
3292 thread.send(UserMessageId::new(), ["Hello!"], cx)
3293 })
3294 .unwrap();
3295 cx.run_until_parked();
3296
3297 fake_model.send_last_completion_stream_text_chunk("Hey!");
3298 fake_model.end_last_completion_stream();
3299
3300 let mut retry_events = Vec::new();
3301 while let Some(Ok(event)) = events.next().await {
3302 match event {
3303 ThreadEvent::Retry(retry_status) => {
3304 retry_events.push(retry_status);
3305 }
3306 ThreadEvent::Stop(..) => break,
3307 _ => {}
3308 }
3309 }
3310
3311 assert_eq!(retry_events.len(), 0);
3312 thread.read_with(cx, |thread, _cx| {
3313 assert_eq!(
3314 thread.to_markdown(),
3315 indoc! {"
3316 ## User
3317
3318 Hello!
3319
3320 ## Assistant
3321
3322 Hey!
3323 "}
3324 )
3325 });
3326}
3327
3328#[gpui::test]
3329async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3330 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3331 let fake_model = model.as_fake();
3332
3333 let mut events = thread
3334 .update(cx, |thread, cx| {
3335 thread.send(UserMessageId::new(), ["Hello!"], cx)
3336 })
3337 .unwrap();
3338 cx.run_until_parked();
3339
3340 fake_model.send_last_completion_stream_text_chunk("Hey,");
3341 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3342 provider: LanguageModelProviderName::new("Anthropic"),
3343 retry_after: Some(Duration::from_secs(3)),
3344 });
3345 fake_model.end_last_completion_stream();
3346
3347 cx.executor().advance_clock(Duration::from_secs(3));
3348 cx.run_until_parked();
3349
3350 fake_model.send_last_completion_stream_text_chunk("there!");
3351 fake_model.end_last_completion_stream();
3352 cx.run_until_parked();
3353
3354 let mut retry_events = Vec::new();
3355 while let Some(Ok(event)) = events.next().await {
3356 match event {
3357 ThreadEvent::Retry(retry_status) => {
3358 retry_events.push(retry_status);
3359 }
3360 ThreadEvent::Stop(..) => break,
3361 _ => {}
3362 }
3363 }
3364
3365 assert_eq!(retry_events.len(), 1);
3366 assert!(matches!(
3367 retry_events[0],
3368 acp_thread::RetryStatus { attempt: 1, .. }
3369 ));
3370 thread.read_with(cx, |thread, _cx| {
3371 assert_eq!(
3372 thread.to_markdown(),
3373 indoc! {"
3374 ## User
3375
3376 Hello!
3377
3378 ## Assistant
3379
3380 Hey,
3381
3382 [resume]
3383
3384 ## Assistant
3385
3386 there!
3387 "}
3388 )
3389 });
3390}
3391
3392#[gpui::test]
3393async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3394 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3395 let fake_model = model.as_fake();
3396
3397 let events = thread
3398 .update(cx, |thread, cx| {
3399 thread.add_tool(EchoTool, None);
3400 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3401 })
3402 .unwrap();
3403 cx.run_until_parked();
3404
3405 let tool_use_1 = LanguageModelToolUse {
3406 id: "tool_1".into(),
3407 name: EchoTool::NAME.into(),
3408 raw_input: json!({"text": "test"}).to_string(),
3409 input: json!({"text": "test"}),
3410 is_input_complete: true,
3411 thought_signature: None,
3412 };
3413 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3414 tool_use_1.clone(),
3415 ));
3416 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3417 provider: LanguageModelProviderName::new("Anthropic"),
3418 retry_after: Some(Duration::from_secs(3)),
3419 });
3420 fake_model.end_last_completion_stream();
3421
3422 cx.executor().advance_clock(Duration::from_secs(3));
3423 let completion = fake_model.pending_completions().pop().unwrap();
3424 assert_eq!(
3425 completion.messages[1..],
3426 vec![
3427 LanguageModelRequestMessage {
3428 role: Role::User,
3429 content: vec!["Call the echo tool!".into()],
3430 cache: false,
3431 reasoning_details: None,
3432 },
3433 LanguageModelRequestMessage {
3434 role: Role::Assistant,
3435 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3436 cache: false,
3437 reasoning_details: None,
3438 },
3439 LanguageModelRequestMessage {
3440 role: Role::User,
3441 content: vec![language_model::MessageContent::ToolResult(
3442 LanguageModelToolResult {
3443 tool_use_id: tool_use_1.id.clone(),
3444 tool_name: tool_use_1.name.clone(),
3445 is_error: false,
3446 content: "test".into(),
3447 output: Some("test".into())
3448 }
3449 )],
3450 cache: true,
3451 reasoning_details: None,
3452 },
3453 ]
3454 );
3455
3456 fake_model.send_last_completion_stream_text_chunk("Done");
3457 fake_model.end_last_completion_stream();
3458 cx.run_until_parked();
3459 events.collect::<Vec<_>>().await;
3460 thread.read_with(cx, |thread, _cx| {
3461 assert_eq!(
3462 thread.last_message(),
3463 Some(Message::Agent(AgentMessage {
3464 content: vec![AgentMessageContent::Text("Done".into())],
3465 tool_results: IndexMap::default(),
3466 reasoning_details: None,
3467 }))
3468 );
3469 })
3470}
3471
3472#[gpui::test]
3473async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3474 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3475 let fake_model = model.as_fake();
3476
3477 let mut events = thread
3478 .update(cx, |thread, cx| {
3479 thread.send(UserMessageId::new(), ["Hello!"], cx)
3480 })
3481 .unwrap();
3482 cx.run_until_parked();
3483
3484 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3485 fake_model.send_last_completion_stream_error(
3486 LanguageModelCompletionError::ServerOverloaded {
3487 provider: LanguageModelProviderName::new("Anthropic"),
3488 retry_after: Some(Duration::from_secs(3)),
3489 },
3490 );
3491 fake_model.end_last_completion_stream();
3492 cx.executor().advance_clock(Duration::from_secs(3));
3493 cx.run_until_parked();
3494 }
3495
3496 let mut errors = Vec::new();
3497 let mut retry_events = Vec::new();
3498 while let Some(event) = events.next().await {
3499 match event {
3500 Ok(ThreadEvent::Retry(retry_status)) => {
3501 retry_events.push(retry_status);
3502 }
3503 Ok(ThreadEvent::Stop(..)) => break,
3504 Err(error) => errors.push(error),
3505 _ => {}
3506 }
3507 }
3508
3509 assert_eq!(
3510 retry_events.len(),
3511 crate::thread::MAX_RETRY_ATTEMPTS as usize
3512 );
3513 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3514 assert_eq!(retry_events[i].attempt, i + 1);
3515 }
3516 assert_eq!(errors.len(), 1);
3517 let error = errors[0]
3518 .downcast_ref::<LanguageModelCompletionError>()
3519 .unwrap();
3520 assert!(matches!(
3521 error,
3522 LanguageModelCompletionError::ServerOverloaded { .. }
3523 ));
3524}
3525
3526/// Filters out the stop events for asserting against in tests
3527fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3528 result_events
3529 .into_iter()
3530 .filter_map(|event| match event.unwrap() {
3531 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3532 _ => None,
3533 })
3534 .collect()
3535}
3536
/// Bundle of handles produced by `setup` for driving a `Thread` in tests.
struct ThreadTest {
    /// The language model handed to the thread when it was created.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity passed to the thread.
    project_context: Entity<ProjectContext>,
    /// The context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake file system holding the settings file and the test project tree.
    fs: Arc<FakeFs>,
}
3544
/// Selects which language model `setup` wires into the thread under test.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TestModel {
    /// A real model, looked up in the language model registry by its id.
    Sonnet4,
    /// An in-process `FakeLanguageModel` driven entirely by the test.
    Fake,
}
3549
3550impl TestModel {
3551 fn id(&self) -> LanguageModelId {
3552 match self {
3553 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3554 TestModel::Fake => unreachable!(),
3555 }
3556 }
3557}
3558
3559async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3560 cx.executor().allow_parking();
3561
3562 let fs = FakeFs::new(cx.background_executor.clone());
3563 fs.create_dir(paths::settings_file().parent().unwrap())
3564 .await
3565 .unwrap();
3566 fs.insert_file(
3567 paths::settings_file(),
3568 json!({
3569 "agent": {
3570 "default_profile": "test-profile",
3571 "profiles": {
3572 "test-profile": {
3573 "name": "Test Profile",
3574 "tools": {
3575 EchoTool::NAME: true,
3576 DelayTool::NAME: true,
3577 WordListTool::NAME: true,
3578 ToolRequiringPermission::NAME: true,
3579 InfiniteTool::NAME: true,
3580 CancellationAwareTool::NAME: true,
3581 (TerminalTool::NAME): true,
3582 }
3583 }
3584 }
3585 }
3586 })
3587 .to_string()
3588 .into_bytes(),
3589 )
3590 .await;
3591
3592 cx.update(|cx| {
3593 settings::init(cx);
3594
3595 match model {
3596 TestModel::Fake => {}
3597 TestModel::Sonnet4 => {
3598 gpui_tokio::init(cx);
3599 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3600 cx.set_http_client(Arc::new(http_client));
3601 let client = Client::production(cx);
3602 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3603 language_model::init(client.clone(), cx);
3604 language_models::init(user_store, client.clone(), cx);
3605 }
3606 };
3607
3608 watch_settings(fs.clone(), cx);
3609 });
3610
3611 let templates = Templates::new();
3612
3613 fs.insert_tree(path!("/test"), json!({})).await;
3614 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3615
3616 let model = cx
3617 .update(|cx| {
3618 if let TestModel::Fake = model {
3619 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3620 } else {
3621 let model_id = model.id();
3622 let models = LanguageModelRegistry::read_global(cx);
3623 let model = models
3624 .available_models(cx)
3625 .find(|model| model.id() == model_id)
3626 .unwrap();
3627
3628 let provider = models.provider(&model.provider_id()).unwrap();
3629 let authenticated = provider.authenticate(cx);
3630
3631 cx.spawn(async move |_cx| {
3632 authenticated.await.unwrap();
3633 model
3634 })
3635 }
3636 })
3637 .await;
3638
3639 let project_context = cx.new(|_cx| ProjectContext::default());
3640 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3641 let context_server_registry =
3642 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3643 let thread = cx.new(|cx| {
3644 Thread::new(
3645 project,
3646 project_context.clone(),
3647 context_server_registry,
3648 templates,
3649 Some(model.clone()),
3650 cx,
3651 )
3652 });
3653 ThreadTest {
3654 model,
3655 thread,
3656 project_context,
3657 context_server_store,
3658 fs,
3659 }
3660}
3661
/// Initializes `env_logger` at load time, but only when `RUST_LOG` is set to a
/// readable value, so ordinary test runs stay quiet.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}
3669
3670fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3671 let fs = fs.clone();
3672 cx.spawn({
3673 async move |cx| {
3674 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3675 cx.background_executor(),
3676 fs,
3677 paths::settings_file().clone(),
3678 );
3679 let _watcher_task = watcher_task;
3680
3681 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3682 cx.update(|cx| {
3683 SettingsStore::update_global(cx, |settings, cx| {
3684 settings.set_user_settings(&new_settings_content, cx)
3685 })
3686 })
3687 .ok();
3688 }
3689 }
3690 })
3691 .detach();
3692}
3693
3694fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3695 completion
3696 .tools
3697 .iter()
3698 .map(|tool| tool.name.clone())
3699 .collect()
3700}
3701
3702fn setup_context_server(
3703 name: &'static str,
3704 tools: Vec<context_server::types::Tool>,
3705 context_server_store: &Entity<ContextServerStore>,
3706 cx: &mut TestAppContext,
3707) -> mpsc::UnboundedReceiver<(
3708 context_server::types::CallToolParams,
3709 oneshot::Sender<context_server::types::CallToolResponse>,
3710)> {
3711 cx.update(|cx| {
3712 let mut settings = ProjectSettings::get_global(cx).clone();
3713 settings.context_servers.insert(
3714 name.into(),
3715 project::project_settings::ContextServerSettings::Stdio {
3716 enabled: true,
3717 remote: false,
3718 command: ContextServerCommand {
3719 path: "somebinary".into(),
3720 args: Vec::new(),
3721 env: None,
3722 timeout: None,
3723 },
3724 },
3725 );
3726 ProjectSettings::override_global(settings, cx);
3727 });
3728
3729 let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3730 let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3731 .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3732 context_server::types::InitializeResponse {
3733 protocol_version: context_server::types::ProtocolVersion(
3734 context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3735 ),
3736 server_info: context_server::types::Implementation {
3737 name: name.into(),
3738 version: "1.0.0".to_string(),
3739 },
3740 capabilities: context_server::types::ServerCapabilities {
3741 tools: Some(context_server::types::ToolsCapabilities {
3742 list_changed: Some(true),
3743 }),
3744 ..Default::default()
3745 },
3746 meta: None,
3747 }
3748 })
3749 .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3750 let tools = tools.clone();
3751 async move {
3752 context_server::types::ListToolsResponse {
3753 tools,
3754 next_cursor: None,
3755 meta: None,
3756 }
3757 }
3758 })
3759 .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3760 let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3761 async move {
3762 let (response_tx, response_rx) = oneshot::channel();
3763 mcp_tool_calls_tx
3764 .unbounded_send((params, response_tx))
3765 .unwrap();
3766 response_rx.await.unwrap()
3767 }
3768 });
3769 context_server_store.update(cx, |store, cx| {
3770 store.start_server(
3771 Arc::new(ContextServer::new(
3772 ContextServerId(name.into()),
3773 Arc::new(fake_transport),
3774 )),
3775 cx,
3776 );
3777 });
3778 cx.run_until_parked();
3779 mcp_tool_calls_rx
3780}
3781
3782#[gpui::test]
3783async fn test_tokens_before_message(cx: &mut TestAppContext) {
3784 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3785 let fake_model = model.as_fake();
3786
3787 // First message
3788 let message_1_id = UserMessageId::new();
3789 thread
3790 .update(cx, |thread, cx| {
3791 thread.send(message_1_id.clone(), ["First message"], cx)
3792 })
3793 .unwrap();
3794 cx.run_until_parked();
3795
3796 // Before any response, tokens_before_message should return None for first message
3797 thread.read_with(cx, |thread, _| {
3798 assert_eq!(
3799 thread.tokens_before_message(&message_1_id),
3800 None,
3801 "First message should have no tokens before it"
3802 );
3803 });
3804
3805 // Complete first message with usage
3806 fake_model.send_last_completion_stream_text_chunk("Response 1");
3807 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3808 language_model::TokenUsage {
3809 input_tokens: 100,
3810 output_tokens: 50,
3811 cache_creation_input_tokens: 0,
3812 cache_read_input_tokens: 0,
3813 },
3814 ));
3815 fake_model.end_last_completion_stream();
3816 cx.run_until_parked();
3817
3818 // First message still has no tokens before it
3819 thread.read_with(cx, |thread, _| {
3820 assert_eq!(
3821 thread.tokens_before_message(&message_1_id),
3822 None,
3823 "First message should still have no tokens before it after response"
3824 );
3825 });
3826
3827 // Second message
3828 let message_2_id = UserMessageId::new();
3829 thread
3830 .update(cx, |thread, cx| {
3831 thread.send(message_2_id.clone(), ["Second message"], cx)
3832 })
3833 .unwrap();
3834 cx.run_until_parked();
3835
3836 // Second message should have first message's input tokens before it
3837 thread.read_with(cx, |thread, _| {
3838 assert_eq!(
3839 thread.tokens_before_message(&message_2_id),
3840 Some(100),
3841 "Second message should have 100 tokens before it (from first request)"
3842 );
3843 });
3844
3845 // Complete second message
3846 fake_model.send_last_completion_stream_text_chunk("Response 2");
3847 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3848 language_model::TokenUsage {
3849 input_tokens: 250, // Total for this request (includes previous context)
3850 output_tokens: 75,
3851 cache_creation_input_tokens: 0,
3852 cache_read_input_tokens: 0,
3853 },
3854 ));
3855 fake_model.end_last_completion_stream();
3856 cx.run_until_parked();
3857
3858 // Third message
3859 let message_3_id = UserMessageId::new();
3860 thread
3861 .update(cx, |thread, cx| {
3862 thread.send(message_3_id.clone(), ["Third message"], cx)
3863 })
3864 .unwrap();
3865 cx.run_until_parked();
3866
3867 // Third message should have second message's input tokens (250) before it
3868 thread.read_with(cx, |thread, _| {
3869 assert_eq!(
3870 thread.tokens_before_message(&message_3_id),
3871 Some(250),
3872 "Third message should have 250 tokens before it (from second request)"
3873 );
3874 // Second message should still have 100
3875 assert_eq!(
3876 thread.tokens_before_message(&message_2_id),
3877 Some(100),
3878 "Second message should still have 100 tokens before it"
3879 );
3880 // First message still has none
3881 assert_eq!(
3882 thread.tokens_before_message(&message_1_id),
3883 None,
3884 "First message should still have no tokens before it"
3885 );
3886 });
3887}
3888
3889#[gpui::test]
3890async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3891 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3892 let fake_model = model.as_fake();
3893
3894 // Set up three messages with responses
3895 let message_1_id = UserMessageId::new();
3896 thread
3897 .update(cx, |thread, cx| {
3898 thread.send(message_1_id.clone(), ["Message 1"], cx)
3899 })
3900 .unwrap();
3901 cx.run_until_parked();
3902 fake_model.send_last_completion_stream_text_chunk("Response 1");
3903 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3904 language_model::TokenUsage {
3905 input_tokens: 100,
3906 output_tokens: 50,
3907 cache_creation_input_tokens: 0,
3908 cache_read_input_tokens: 0,
3909 },
3910 ));
3911 fake_model.end_last_completion_stream();
3912 cx.run_until_parked();
3913
3914 let message_2_id = UserMessageId::new();
3915 thread
3916 .update(cx, |thread, cx| {
3917 thread.send(message_2_id.clone(), ["Message 2"], cx)
3918 })
3919 .unwrap();
3920 cx.run_until_parked();
3921 fake_model.send_last_completion_stream_text_chunk("Response 2");
3922 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3923 language_model::TokenUsage {
3924 input_tokens: 250,
3925 output_tokens: 75,
3926 cache_creation_input_tokens: 0,
3927 cache_read_input_tokens: 0,
3928 },
3929 ));
3930 fake_model.end_last_completion_stream();
3931 cx.run_until_parked();
3932
3933 // Verify initial state
3934 thread.read_with(cx, |thread, _| {
3935 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3936 });
3937
3938 // Truncate at message 2 (removes message 2 and everything after)
3939 thread
3940 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3941 .unwrap();
3942 cx.run_until_parked();
3943
3944 // After truncation, message_2_id no longer exists, so lookup should return None
3945 thread.read_with(cx, |thread, _| {
3946 assert_eq!(
3947 thread.tokens_before_message(&message_2_id),
3948 None,
3949 "After truncation, message 2 no longer exists"
3950 );
3951 // Message 1 still exists but has no tokens before it
3952 assert_eq!(
3953 thread.tokens_before_message(&message_1_id),
3954 None,
3955 "First message still has no tokens before it"
3956 );
3957 });
3958}
3959
3960#[gpui::test]
3961async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3962 init_test(cx);
3963
3964 let fs = FakeFs::new(cx.executor());
3965 fs.insert_tree("/root", json!({})).await;
3966 let project = Project::test(fs, ["/root".as_ref()], cx).await;
3967
3968 // Test 1: Deny rule blocks command
3969 {
3970 let environment = Rc::new(cx.update(|cx| {
3971 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
3972 }));
3973
3974 cx.update(|cx| {
3975 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3976 settings.tool_permissions.tools.insert(
3977 TerminalTool::NAME.into(),
3978 agent_settings::ToolRules {
3979 default: Some(settings::ToolPermissionMode::Confirm),
3980 always_allow: vec![],
3981 always_deny: vec![
3982 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3983 ],
3984 always_confirm: vec![],
3985 invalid_patterns: vec![],
3986 },
3987 );
3988 agent_settings::AgentSettings::override_global(settings, cx);
3989 });
3990
3991 #[allow(clippy::arc_with_non_send_sync)]
3992 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3993 let (event_stream, _rx) = crate::ToolCallEventStream::test();
3994
3995 let task = cx.update(|cx| {
3996 tool.run(
3997 crate::TerminalToolInput {
3998 command: "rm -rf /".to_string(),
3999 cd: ".".to_string(),
4000 timeout_ms: None,
4001 },
4002 event_stream,
4003 cx,
4004 )
4005 });
4006
4007 let result = task.await;
4008 assert!(
4009 result.is_err(),
4010 "expected command to be blocked by deny rule"
4011 );
4012 let err_msg = result.unwrap_err().to_string().to_lowercase();
4013 assert!(
4014 err_msg.contains("blocked"),
4015 "error should mention the command was blocked"
4016 );
4017 }
4018
4019 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4020 {
4021 let environment = Rc::new(cx.update(|cx| {
4022 FakeThreadEnvironment::default()
4023 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4024 }));
4025
4026 cx.update(|cx| {
4027 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4028 settings.tool_permissions.tools.insert(
4029 TerminalTool::NAME.into(),
4030 agent_settings::ToolRules {
4031 default: Some(settings::ToolPermissionMode::Deny),
4032 always_allow: vec![
4033 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4034 ],
4035 always_deny: vec![],
4036 always_confirm: vec![],
4037 invalid_patterns: vec![],
4038 },
4039 );
4040 agent_settings::AgentSettings::override_global(settings, cx);
4041 });
4042
4043 #[allow(clippy::arc_with_non_send_sync)]
4044 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4045 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4046
4047 let task = cx.update(|cx| {
4048 tool.run(
4049 crate::TerminalToolInput {
4050 command: "echo hello".to_string(),
4051 cd: ".".to_string(),
4052 timeout_ms: None,
4053 },
4054 event_stream,
4055 cx,
4056 )
4057 });
4058
4059 let update = rx.expect_update_fields().await;
4060 assert!(
4061 update.content.iter().any(|blocks| {
4062 blocks
4063 .iter()
4064 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4065 }),
4066 "expected terminal content (allow rule should skip confirmation and override default deny)"
4067 );
4068
4069 let result = task.await;
4070 assert!(
4071 result.is_ok(),
4072 "expected command to succeed without confirmation"
4073 );
4074 }
4075
4076 // Test 3: global default: allow does NOT override always_confirm patterns
4077 {
4078 let environment = Rc::new(cx.update(|cx| {
4079 FakeThreadEnvironment::default()
4080 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4081 }));
4082
4083 cx.update(|cx| {
4084 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4085 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4086 settings.tool_permissions.tools.insert(
4087 TerminalTool::NAME.into(),
4088 agent_settings::ToolRules {
4089 default: Some(settings::ToolPermissionMode::Allow),
4090 always_allow: vec![],
4091 always_deny: vec![],
4092 always_confirm: vec![
4093 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4094 ],
4095 invalid_patterns: vec![],
4096 },
4097 );
4098 agent_settings::AgentSettings::override_global(settings, cx);
4099 });
4100
4101 #[allow(clippy::arc_with_non_send_sync)]
4102 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4103 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4104
4105 let _task = cx.update(|cx| {
4106 tool.run(
4107 crate::TerminalToolInput {
4108 command: "sudo rm file".to_string(),
4109 cd: ".".to_string(),
4110 timeout_ms: None,
4111 },
4112 event_stream,
4113 cx,
4114 )
4115 });
4116
4117 // With global default: allow, confirm patterns are still respected
4118 // The expect_authorization() call will panic if no authorization is requested,
4119 // which validates that the confirm pattern still triggers confirmation
4120 let _auth = rx.expect_authorization().await;
4121
4122 drop(_task);
4123 }
4124
4125 // Test 4: tool-specific default: deny is respected even with global default: allow
4126 {
4127 let environment = Rc::new(cx.update(|cx| {
4128 FakeThreadEnvironment::default()
4129 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4130 }));
4131
4132 cx.update(|cx| {
4133 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4134 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4135 settings.tool_permissions.tools.insert(
4136 TerminalTool::NAME.into(),
4137 agent_settings::ToolRules {
4138 default: Some(settings::ToolPermissionMode::Deny),
4139 always_allow: vec![],
4140 always_deny: vec![],
4141 always_confirm: vec![],
4142 invalid_patterns: vec![],
4143 },
4144 );
4145 agent_settings::AgentSettings::override_global(settings, cx);
4146 });
4147
4148 #[allow(clippy::arc_with_non_send_sync)]
4149 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4150 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4151
4152 let task = cx.update(|cx| {
4153 tool.run(
4154 crate::TerminalToolInput {
4155 command: "echo hello".to_string(),
4156 cd: ".".to_string(),
4157 timeout_ms: None,
4158 },
4159 event_stream,
4160 cx,
4161 )
4162 });
4163
4164 // tool-specific default: deny is respected even with global default: allow
4165 let result = task.await;
4166 assert!(
4167 result.is_err(),
4168 "expected command to be blocked by tool-specific deny default"
4169 );
4170 let err_msg = result.unwrap_err().to_string().to_lowercase();
4171 assert!(
4172 err_msg.contains("disabled"),
4173 "error should mention the tool is disabled, got: {err_msg}"
4174 );
4175 }
4176}
4177
4178#[gpui::test]
4179async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4180 init_test(cx);
4181 cx.update(|cx| {
4182 LanguageModelRegistry::test(cx);
4183 });
4184 cx.update(|cx| {
4185 cx.update_flags(true, vec!["subagents".to_string()]);
4186 });
4187
4188 let fs = FakeFs::new(cx.executor());
4189 fs.insert_tree(
4190 "/",
4191 json!({
4192 "a": {
4193 "b.md": "Lorem"
4194 }
4195 }),
4196 )
4197 .await;
4198 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4199 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4200 let agent = NativeAgent::new(
4201 project.clone(),
4202 thread_store.clone(),
4203 Templates::new(),
4204 None,
4205 fs.clone(),
4206 &mut cx.to_async(),
4207 )
4208 .await
4209 .unwrap();
4210 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4211
4212 let acp_thread = cx
4213 .update(|cx| {
4214 connection
4215 .clone()
4216 .new_session(project.clone(), Path::new(""), cx)
4217 })
4218 .await
4219 .unwrap();
4220 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4221 let thread = agent.read_with(cx, |agent, _| {
4222 agent.sessions.get(&session_id).unwrap().thread.clone()
4223 });
4224 let model = Arc::new(FakeLanguageModel::default());
4225
4226 // Ensure empty threads are not saved, even if they get mutated.
4227 thread.update(cx, |thread, cx| {
4228 thread.set_model(model.clone(), cx);
4229 });
4230 cx.run_until_parked();
4231
4232 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4233 cx.run_until_parked();
4234 model.send_last_completion_stream_text_chunk("spawning subagent");
4235 let subagent_tool_input = SubagentToolInput {
4236 label: "label".to_string(),
4237 task_prompt: "subagent task prompt".to_string(),
4238 summary_prompt: "subagent summary prompt".to_string(),
4239 timeout_ms: None,
4240 allowed_tools: None,
4241 };
4242 let subagent_tool_use = LanguageModelToolUse {
4243 id: "subagent_1".into(),
4244 name: SubagentTool::NAME.into(),
4245 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4246 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4247 is_input_complete: true,
4248 thought_signature: None,
4249 };
4250 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4251 subagent_tool_use,
4252 ));
4253 model.end_last_completion_stream();
4254
4255 cx.run_until_parked();
4256
4257 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4258 thread
4259 .running_subagent_ids(cx)
4260 .get(0)
4261 .expect("subagent thread should be running")
4262 .clone()
4263 });
4264
4265 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4266 agent
4267 .sessions
4268 .get(&subagent_session_id)
4269 .expect("subagent session should exist")
4270 .acp_thread
4271 .clone()
4272 });
4273
4274 model.send_last_completion_stream_text_chunk("subagent task response");
4275 model.end_last_completion_stream();
4276
4277 cx.run_until_parked();
4278
4279 model.send_last_completion_stream_text_chunk("subagent summary response");
4280 model.end_last_completion_stream();
4281
4282 cx.run_until_parked();
4283
4284 assert_eq!(
4285 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4286 indoc! {"
4287 ## User
4288
4289 subagent task prompt
4290
4291 ## Assistant
4292
4293 subagent task response
4294
4295 ## User
4296
4297 subagent summary prompt
4298
4299 ## Assistant
4300
4301 subagent summary response
4302
4303 "}
4304 );
4305
4306 model.send_last_completion_stream_text_chunk("Response");
4307 model.end_last_completion_stream();
4308
4309 send.await.unwrap();
4310
4311 assert_eq!(
4312 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4313 format!(
4314 indoc! {r#"
4315 ## User
4316
4317 Prompt
4318
4319 ## Assistant
4320
4321 spawning subagent
4322
4323 **Tool Call: label**
4324 Status: Completed
4325
4326 ```json
4327 {{
4328 "subagent_session_id": "{}",
4329 "summary": "subagent summary response\n"
4330 }}
4331 ```
4332
4333 ## Assistant
4334
4335 Response
4336
4337 "#},
4338 subagent_session_id
4339 )
4340 );
4341}
4342
4343#[gpui::test]
4344async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4345 init_test(cx);
4346 cx.update(|cx| {
4347 LanguageModelRegistry::test(cx);
4348 });
4349 cx.update(|cx| {
4350 cx.update_flags(true, vec!["subagents".to_string()]);
4351 });
4352
4353 let fs = FakeFs::new(cx.executor());
4354 fs.insert_tree(
4355 "/",
4356 json!({
4357 "a": {
4358 "b.md": "Lorem"
4359 }
4360 }),
4361 )
4362 .await;
4363 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4364 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4365 let agent = NativeAgent::new(
4366 project.clone(),
4367 thread_store.clone(),
4368 Templates::new(),
4369 None,
4370 fs.clone(),
4371 &mut cx.to_async(),
4372 )
4373 .await
4374 .unwrap();
4375 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4376
4377 let acp_thread = cx
4378 .update(|cx| {
4379 connection
4380 .clone()
4381 .new_session(project.clone(), Path::new(""), cx)
4382 })
4383 .await
4384 .unwrap();
4385 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4386 let thread = agent.read_with(cx, |agent, _| {
4387 agent.sessions.get(&session_id).unwrap().thread.clone()
4388 });
4389 let model = Arc::new(FakeLanguageModel::default());
4390
4391 // Ensure empty threads are not saved, even if they get mutated.
4392 thread.update(cx, |thread, cx| {
4393 thread.set_model(model.clone(), cx);
4394 });
4395 cx.run_until_parked();
4396
4397 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4398 cx.run_until_parked();
4399 model.send_last_completion_stream_text_chunk("spawning subagent");
4400 let subagent_tool_input = SubagentToolInput {
4401 label: "label".to_string(),
4402 task_prompt: "subagent task prompt".to_string(),
4403 summary_prompt: "subagent summary prompt".to_string(),
4404 timeout_ms: None,
4405 allowed_tools: None,
4406 };
4407 let subagent_tool_use = LanguageModelToolUse {
4408 id: "subagent_1".into(),
4409 name: SubagentTool::NAME.into(),
4410 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4411 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4412 is_input_complete: true,
4413 thought_signature: None,
4414 };
4415 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4416 subagent_tool_use,
4417 ));
4418 model.end_last_completion_stream();
4419
4420 cx.run_until_parked();
4421
4422 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4423 thread
4424 .running_subagent_ids(cx)
4425 .get(0)
4426 .expect("subagent thread should be running")
4427 .clone()
4428 });
4429 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4430 agent
4431 .sessions
4432 .get(&subagent_session_id)
4433 .expect("subagent session should exist")
4434 .acp_thread
4435 .clone()
4436 });
4437
4438 // model.send_last_completion_stream_text_chunk("subagent task response");
4439 // model.end_last_completion_stream();
4440
4441 // cx.run_until_parked();
4442
4443 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4444
4445 cx.run_until_parked();
4446
4447 send.await.unwrap();
4448
4449 acp_thread.read_with(cx, |thread, cx| {
4450 assert_eq!(thread.status(), ThreadStatus::Idle);
4451 assert_eq!(
4452 thread.to_markdown(cx),
4453 indoc! {"
4454 ## User
4455
4456 Prompt
4457
4458 ## Assistant
4459
4460 spawning subagent
4461
4462 **Tool Call: label**
4463 Status: Canceled
4464
4465 "}
4466 );
4467 });
4468 subagent_acp_thread.read_with(cx, |thread, cx| {
4469 assert_eq!(thread.status(), ThreadStatus::Idle);
4470 assert_eq!(
4471 thread.to_markdown(cx),
4472 indoc! {"
4473 ## User
4474
4475 subagent task prompt
4476
4477 "}
4478 );
4479 });
4480}
4481
4482#[gpui::test]
4483async fn test_subagent_tool_call_cancellation_during_summary_prompt(cx: &mut TestAppContext) {
4484 init_test(cx);
4485 cx.update(|cx| {
4486 LanguageModelRegistry::test(cx);
4487 });
4488 cx.update(|cx| {
4489 cx.update_flags(true, vec!["subagents".to_string()]);
4490 });
4491
4492 let fs = FakeFs::new(cx.executor());
4493 fs.insert_tree(
4494 "/",
4495 json!({
4496 "a": {
4497 "b.md": "Lorem"
4498 }
4499 }),
4500 )
4501 .await;
4502 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4503 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4504 let agent = NativeAgent::new(
4505 project.clone(),
4506 thread_store.clone(),
4507 Templates::new(),
4508 None,
4509 fs.clone(),
4510 &mut cx.to_async(),
4511 )
4512 .await
4513 .unwrap();
4514 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4515
4516 let acp_thread = cx
4517 .update(|cx| {
4518 connection
4519 .clone()
4520 .new_session(project.clone(), Path::new(""), cx)
4521 })
4522 .await
4523 .unwrap();
4524 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4525 let thread = agent.read_with(cx, |agent, _| {
4526 agent.sessions.get(&session_id).unwrap().thread.clone()
4527 });
4528 let model = Arc::new(FakeLanguageModel::default());
4529
4530 // Ensure empty threads are not saved, even if they get mutated.
4531 thread.update(cx, |thread, cx| {
4532 thread.set_model(model.clone(), cx);
4533 });
4534 cx.run_until_parked();
4535
4536 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4537 cx.run_until_parked();
4538 model.send_last_completion_stream_text_chunk("spawning subagent");
4539 let subagent_tool_input = SubagentToolInput {
4540 label: "label".to_string(),
4541 task_prompt: "subagent task prompt".to_string(),
4542 summary_prompt: "subagent summary prompt".to_string(),
4543 timeout_ms: None,
4544 allowed_tools: None,
4545 };
4546 let subagent_tool_use = LanguageModelToolUse {
4547 id: "subagent_1".into(),
4548 name: SubagentTool::NAME.into(),
4549 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4550 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4551 is_input_complete: true,
4552 thought_signature: None,
4553 };
4554 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4555 subagent_tool_use,
4556 ));
4557 model.end_last_completion_stream();
4558
4559 cx.run_until_parked();
4560
4561 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4562 thread
4563 .running_subagent_ids(cx)
4564 .get(0)
4565 .expect("subagent thread should be running")
4566 .clone()
4567 });
4568 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4569 agent
4570 .sessions
4571 .get(&subagent_session_id)
4572 .expect("subagent session should exist")
4573 .acp_thread
4574 .clone()
4575 });
4576
4577 model.send_last_completion_stream_text_chunk("subagent task response");
4578 model.end_last_completion_stream();
4579
4580 cx.run_until_parked();
4581
4582 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4583
4584 cx.run_until_parked();
4585
4586 send.await.unwrap();
4587
4588 acp_thread.read_with(cx, |thread, cx| {
4589 assert_eq!(thread.status(), ThreadStatus::Idle);
4590 assert_eq!(
4591 thread.to_markdown(cx),
4592 indoc! {"
4593 ## User
4594
4595 Prompt
4596
4597 ## Assistant
4598
4599 spawning subagent
4600
4601 **Tool Call: label**
4602 Status: Canceled
4603
4604 "}
4605 );
4606 });
4607 subagent_acp_thread.read_with(cx, |thread, cx| {
4608 assert_eq!(thread.status(), ThreadStatus::Idle);
4609 assert_eq!(
4610 thread.to_markdown(cx),
4611 indoc! {"
4612 ## User
4613
4614 subagent task prompt
4615
4616 ## Assistant
4617
4618 subagent task response
4619
4620 ## User
4621
4622 subagent summary prompt
4623
4624 "}
4625 );
4626 });
4627}
4628
4629#[gpui::test]
4630async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4631 init_test(cx);
4632
4633 cx.update(|cx| {
4634 cx.update_flags(true, vec!["subagents".to_string()]);
4635 });
4636
4637 let fs = FakeFs::new(cx.executor());
4638 fs.insert_tree(path!("/test"), json!({})).await;
4639 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4640 let project_context = cx.new(|_cx| ProjectContext::default());
4641 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4642 let context_server_registry =
4643 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4644 let model = Arc::new(FakeLanguageModel::default());
4645
4646 let environment = Rc::new(cx.update(|cx| {
4647 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4648 }));
4649
4650 let thread = cx.new(|cx| {
4651 let mut thread = Thread::new(
4652 project.clone(),
4653 project_context,
4654 context_server_registry,
4655 Templates::new(),
4656 Some(model),
4657 cx,
4658 );
4659 thread.add_default_tools(None, environment, cx);
4660 thread
4661 });
4662
4663 thread.read_with(cx, |thread, _| {
4664 assert!(
4665 thread.has_registered_tool(SubagentTool::NAME),
4666 "subagent tool should be present when feature flag is enabled"
4667 );
4668 });
4669}
4670
4671#[gpui::test]
4672async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4673 init_test(cx);
4674
4675 cx.update(|cx| {
4676 cx.update_flags(true, vec!["subagents".to_string()]);
4677 });
4678
4679 let fs = FakeFs::new(cx.executor());
4680 fs.insert_tree(path!("/test"), json!({})).await;
4681 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4682 let project_context = cx.new(|_cx| ProjectContext::default());
4683 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4684 let context_server_registry =
4685 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4686 let model = Arc::new(FakeLanguageModel::default());
4687
4688 let parent_thread = cx.new(|cx| {
4689 Thread::new(
4690 project.clone(),
4691 project_context,
4692 context_server_registry,
4693 Templates::new(),
4694 Some(model.clone()),
4695 cx,
4696 )
4697 });
4698
4699 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4700 subagent_thread.read_with(cx, |subagent_thread, cx| {
4701 assert!(subagent_thread.is_subagent());
4702 assert_eq!(subagent_thread.depth(), 1);
4703 assert_eq!(
4704 subagent_thread.model().map(|model| model.id()),
4705 Some(model.id())
4706 );
4707 assert_eq!(
4708 subagent_thread.parent_thread_id(),
4709 Some(parent_thread.read(cx).id().clone())
4710 );
4711 });
4712}
4713
4714#[gpui::test]
4715async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4716 init_test(cx);
4717
4718 cx.update(|cx| {
4719 cx.update_flags(true, vec!["subagents".to_string()]);
4720 });
4721
4722 let fs = FakeFs::new(cx.executor());
4723 fs.insert_tree(path!("/test"), json!({})).await;
4724 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4725 let project_context = cx.new(|_cx| ProjectContext::default());
4726 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4727 let context_server_registry =
4728 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4729 let model = Arc::new(FakeLanguageModel::default());
4730 let environment = Rc::new(cx.update(|cx| {
4731 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4732 }));
4733
4734 let deep_parent_thread = cx.new(|cx| {
4735 let mut thread = Thread::new(
4736 project.clone(),
4737 project_context,
4738 context_server_registry,
4739 Templates::new(),
4740 Some(model.clone()),
4741 cx,
4742 );
4743 thread.set_subagent_context(SubagentContext {
4744 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4745 depth: MAX_SUBAGENT_DEPTH - 1,
4746 });
4747 thread
4748 });
4749 let deep_subagent_thread = cx.new(|cx| {
4750 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4751 thread.add_default_tools(None, environment, cx);
4752 thread
4753 });
4754
4755 deep_subagent_thread.read_with(cx, |thread, _| {
4756 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4757 assert!(
4758 !thread.has_registered_tool(SubagentTool::NAME),
4759 "subagent tool should not be present at max depth"
4760 );
4761 });
4762}
4763
4764#[gpui::test]
4765async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4766 init_test(cx);
4767
4768 cx.update(|cx| {
4769 cx.update_flags(true, vec!["subagents".to_string()]);
4770 });
4771
4772 let fs = FakeFs::new(cx.executor());
4773 fs.insert_tree(path!("/test"), json!({})).await;
4774 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4775 let project_context = cx.new(|_cx| ProjectContext::default());
4776 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4777 let context_server_registry =
4778 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4779 let model = Arc::new(FakeLanguageModel::default());
4780
4781 let parent = cx.new(|cx| {
4782 Thread::new(
4783 project.clone(),
4784 project_context.clone(),
4785 context_server_registry.clone(),
4786 Templates::new(),
4787 Some(model.clone()),
4788 cx,
4789 )
4790 });
4791
4792 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4793
4794 parent.update(cx, |thread, _cx| {
4795 thread.register_running_subagent(subagent.downgrade());
4796 });
4797
4798 subagent
4799 .update(cx, |thread, cx| {
4800 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4801 })
4802 .unwrap();
4803 cx.run_until_parked();
4804
4805 subagent.read_with(cx, |thread, _| {
4806 assert!(!thread.is_turn_complete(), "subagent should be running");
4807 });
4808
4809 parent.update(cx, |thread, cx| {
4810 thread.cancel(cx).detach();
4811 });
4812
4813 subagent.read_with(cx, |thread, _| {
4814 assert!(
4815 thread.is_turn_complete(),
4816 "subagent should be cancelled when parent cancels"
4817 );
4818 });
4819}
4820
4821#[gpui::test]
4822async fn test_thread_environment_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4823 init_test(cx);
4824 always_allow_tools(cx);
4825
4826 cx.update(|cx| {
4827 cx.update_flags(true, vec!["subagents".to_string()]);
4828 });
4829
4830 let fs = FakeFs::new(cx.executor());
4831 fs.insert_tree(path!("/test"), json!({})).await;
4832 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4833 let project_context = cx.new(|_cx| ProjectContext::default());
4834 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4835 let context_server_registry =
4836 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4837 cx.update(LanguageModelRegistry::test);
4838 let model = Arc::new(FakeLanguageModel::default());
4839 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4840 let native_agent = NativeAgent::new(
4841 project.clone(),
4842 thread_store,
4843 Templates::new(),
4844 None,
4845 fs,
4846 &mut cx.to_async(),
4847 )
4848 .await
4849 .unwrap();
4850 let parent_thread = cx.new(|cx| {
4851 Thread::new(
4852 project.clone(),
4853 project_context,
4854 context_server_registry,
4855 Templates::new(),
4856 Some(model.clone()),
4857 cx,
4858 )
4859 });
4860
4861 let mut handles = Vec::new();
4862 for _ in 0..MAX_PARALLEL_SUBAGENTS {
4863 let handle = cx
4864 .update(|cx| {
4865 NativeThreadEnvironment::create_subagent_thread(
4866 native_agent.downgrade(),
4867 parent_thread.clone(),
4868 "some title".to_string(),
4869 "some task".to_string(),
4870 None,
4871 None,
4872 cx,
4873 )
4874 })
4875 .expect("Expected to be able to create subagent thread");
4876 handles.push(handle);
4877 }
4878
4879 let result = cx.update(|cx| {
4880 NativeThreadEnvironment::create_subagent_thread(
4881 native_agent.downgrade(),
4882 parent_thread.clone(),
4883 "some title".to_string(),
4884 "some task".to_string(),
4885 None,
4886 None,
4887 cx,
4888 )
4889 });
4890 assert!(result.is_err());
4891 assert_eq!(
4892 result.err().unwrap().to_string(),
4893 format!(
4894 "Maximum parallel subagents ({}) reached. Wait for existing subagents to complete.",
4895 MAX_PARALLEL_SUBAGENTS
4896 )
4897 );
4898}
4899
4900#[gpui::test]
4901async fn test_subagent_tool_returns_summary(cx: &mut TestAppContext) {
4902 init_test(cx);
4903
4904 always_allow_tools(cx);
4905
4906 cx.update(|cx| {
4907 cx.update_flags(true, vec!["subagents".to_string()]);
4908 });
4909
4910 let fs = FakeFs::new(cx.executor());
4911 fs.insert_tree(path!("/test"), json!({})).await;
4912 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
4913 let project_context = cx.new(|_cx| ProjectContext::default());
4914 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4915 let context_server_registry =
4916 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4917 cx.update(LanguageModelRegistry::test);
4918 let model = Arc::new(FakeLanguageModel::default());
4919 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4920 let native_agent = NativeAgent::new(
4921 project.clone(),
4922 thread_store,
4923 Templates::new(),
4924 None,
4925 fs,
4926 &mut cx.to_async(),
4927 )
4928 .await
4929 .unwrap();
4930 let parent_thread = cx.new(|cx| {
4931 Thread::new(
4932 project.clone(),
4933 project_context,
4934 context_server_registry,
4935 Templates::new(),
4936 Some(model.clone()),
4937 cx,
4938 )
4939 });
4940
4941 let subagent_handle = cx
4942 .update(|cx| {
4943 NativeThreadEnvironment::create_subagent_thread(
4944 native_agent.downgrade(),
4945 parent_thread.clone(),
4946 "some title".to_string(),
4947 "task prompt".to_string(),
4948 Some(Duration::from_millis(10)),
4949 None,
4950 cx,
4951 )
4952 })
4953 .expect("Failed to create subagent");
4954
4955 let summary_task =
4956 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
4957
4958 cx.run_until_parked();
4959
4960 {
4961 let messages = model.pending_completions().last().unwrap().messages.clone();
4962 // Ensure that model received a system prompt
4963 assert_eq!(messages[0].role, Role::System);
4964 // Ensure that model received a task prompt
4965 assert_eq!(messages[1].role, Role::User);
4966 assert_eq!(
4967 messages[1].content,
4968 vec![MessageContent::Text("task prompt".to_string())]
4969 );
4970 }
4971
4972 model.send_last_completion_stream_text_chunk("Some task response...");
4973 model.end_last_completion_stream();
4974
4975 cx.run_until_parked();
4976
4977 {
4978 let messages = model.pending_completions().last().unwrap().messages.clone();
4979 assert_eq!(messages[2].role, Role::Assistant);
4980 assert_eq!(
4981 messages[2].content,
4982 vec![MessageContent::Text("Some task response...".to_string())]
4983 );
4984 // Ensure that model received a summary prompt
4985 assert_eq!(messages[3].role, Role::User);
4986 assert_eq!(
4987 messages[3].content,
4988 vec![MessageContent::Text("summary prompt".to_string())]
4989 );
4990 }
4991
4992 model.send_last_completion_stream_text_chunk("Some summary...");
4993 model.end_last_completion_stream();
4994
4995 let result = summary_task.await;
4996 assert_eq!(result.unwrap(), "Some summary...\n");
4997}
4998
4999#[gpui::test]
5000async fn test_subagent_tool_includes_cancellation_notice_when_timeout_is_exceeded(
5001 cx: &mut TestAppContext,
5002) {
5003 init_test(cx);
5004
5005 always_allow_tools(cx);
5006
5007 cx.update(|cx| {
5008 cx.update_flags(true, vec!["subagents".to_string()]);
5009 });
5010
5011 let fs = FakeFs::new(cx.executor());
5012 fs.insert_tree(path!("/test"), json!({})).await;
5013 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
5014 let project_context = cx.new(|_cx| ProjectContext::default());
5015 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5016 let context_server_registry =
5017 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5018 cx.update(LanguageModelRegistry::test);
5019 let model = Arc::new(FakeLanguageModel::default());
5020 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5021 let native_agent = NativeAgent::new(
5022 project.clone(),
5023 thread_store,
5024 Templates::new(),
5025 None,
5026 fs,
5027 &mut cx.to_async(),
5028 )
5029 .await
5030 .unwrap();
5031 let parent_thread = cx.new(|cx| {
5032 Thread::new(
5033 project.clone(),
5034 project_context,
5035 context_server_registry,
5036 Templates::new(),
5037 Some(model.clone()),
5038 cx,
5039 )
5040 });
5041
5042 let subagent_handle = cx
5043 .update(|cx| {
5044 NativeThreadEnvironment::create_subagent_thread(
5045 native_agent.downgrade(),
5046 parent_thread.clone(),
5047 "some title".to_string(),
5048 "task prompt".to_string(),
5049 Some(Duration::from_millis(100)),
5050 None,
5051 cx,
5052 )
5053 })
5054 .expect("Failed to create subagent");
5055
5056 let summary_task =
5057 subagent_handle.wait_for_summary("summary prompt".to_string(), &cx.to_async());
5058
5059 cx.run_until_parked();
5060
5061 {
5062 let messages = model.pending_completions().last().unwrap().messages.clone();
5063 // Ensure that model received a system prompt
5064 assert_eq!(messages[0].role, Role::System);
5065 // Ensure that model received a task prompt
5066 assert_eq!(
5067 messages[1].content,
5068 vec![MessageContent::Text("task prompt".to_string())]
5069 );
5070 }
5071
5072 // Don't complete the initial model stream — let the timeout expire instead.
5073 cx.executor().advance_clock(Duration::from_millis(200));
5074 cx.run_until_parked();
5075
5076 // After the timeout fires, the thread is cancelled and context_low_prompt is sent
5077 // instead of the summary_prompt.
5078 {
5079 let messages = model.pending_completions().last().unwrap().messages.clone();
5080 let last_user_message = messages
5081 .iter()
5082 .rev()
5083 .find(|m| m.role == Role::User)
5084 .unwrap();
5085 assert_eq!(
5086 last_user_message.content,
5087 vec![MessageContent::Text("The time to complete the task was exceeded. Stop with the task and follow the directions below:\nsummary prompt".to_string())]
5088 );
5089 }
5090
5091 model.send_last_completion_stream_text_chunk("Some context low response...");
5092 model.end_last_completion_stream();
5093
5094 let result = summary_task.await;
5095 assert_eq!(result.unwrap(), "Some context low response...\n");
5096}
5097
5098#[gpui::test]
5099async fn test_subagent_inherits_parent_thread_tools(cx: &mut TestAppContext) {
5100 init_test(cx);
5101
5102 always_allow_tools(cx);
5103
5104 cx.update(|cx| {
5105 cx.update_flags(true, vec!["subagents".to_string()]);
5106 });
5107
5108 let fs = FakeFs::new(cx.executor());
5109 fs.insert_tree(path!("/test"), json!({})).await;
5110 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
5111 let project_context = cx.new(|_cx| ProjectContext::default());
5112 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5113 let context_server_registry =
5114 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5115 cx.update(LanguageModelRegistry::test);
5116 let model = Arc::new(FakeLanguageModel::default());
5117 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5118 let native_agent = NativeAgent::new(
5119 project.clone(),
5120 thread_store,
5121 Templates::new(),
5122 None,
5123 fs,
5124 &mut cx.to_async(),
5125 )
5126 .await
5127 .unwrap();
5128 let parent_thread = cx.new(|cx| {
5129 let mut thread = Thread::new(
5130 project.clone(),
5131 project_context,
5132 context_server_registry,
5133 Templates::new(),
5134 Some(model.clone()),
5135 cx,
5136 );
5137 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
5138 thread.add_tool(GrepTool::new(project.clone()), None);
5139 thread
5140 });
5141
5142 let _subagent_handle = cx
5143 .update(|cx| {
5144 NativeThreadEnvironment::create_subagent_thread(
5145 native_agent.downgrade(),
5146 parent_thread.clone(),
5147 "some title".to_string(),
5148 "task prompt".to_string(),
5149 Some(Duration::from_millis(10)),
5150 None,
5151 cx,
5152 )
5153 })
5154 .expect("Failed to create subagent");
5155
5156 cx.run_until_parked();
5157
5158 let tools = model
5159 .pending_completions()
5160 .last()
5161 .unwrap()
5162 .tools
5163 .iter()
5164 .map(|tool| tool.name.clone())
5165 .collect::<Vec<_>>();
5166 assert_eq!(tools.len(), 2);
5167 assert!(tools.contains(&"grep".to_string()));
5168 assert!(tools.contains(&"list_directory".to_string()));
5169}
5170
5171#[gpui::test]
5172async fn test_subagent_tool_restricts_tool_access(cx: &mut TestAppContext) {
5173 init_test(cx);
5174
5175 always_allow_tools(cx);
5176
5177 cx.update(|cx| {
5178 cx.update_flags(true, vec!["subagents".to_string()]);
5179 });
5180
5181 let fs = FakeFs::new(cx.executor());
5182 fs.insert_tree(path!("/test"), json!({})).await;
5183 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
5184 let project_context = cx.new(|_cx| ProjectContext::default());
5185 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5186 let context_server_registry =
5187 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5188 cx.update(LanguageModelRegistry::test);
5189 let model = Arc::new(FakeLanguageModel::default());
5190 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5191 let native_agent = NativeAgent::new(
5192 project.clone(),
5193 thread_store,
5194 Templates::new(),
5195 None,
5196 fs,
5197 &mut cx.to_async(),
5198 )
5199 .await
5200 .unwrap();
5201 let parent_thread = cx.new(|cx| {
5202 let mut thread = Thread::new(
5203 project.clone(),
5204 project_context,
5205 context_server_registry,
5206 Templates::new(),
5207 Some(model.clone()),
5208 cx,
5209 );
5210 thread.add_tool(ListDirectoryTool::new(project.clone()), None);
5211 thread.add_tool(GrepTool::new(project.clone()), None);
5212 thread
5213 });
5214
5215 let _subagent_handle = cx
5216 .update(|cx| {
5217 NativeThreadEnvironment::create_subagent_thread(
5218 native_agent.downgrade(),
5219 parent_thread.clone(),
5220 "some title".to_string(),
5221 "task prompt".to_string(),
5222 Some(Duration::from_millis(10)),
5223 Some(vec!["grep".to_string()]),
5224 cx,
5225 )
5226 })
5227 .expect("Failed to create subagent");
5228
5229 cx.run_until_parked();
5230
5231 let tools = model
5232 .pending_completions()
5233 .last()
5234 .unwrap()
5235 .tools
5236 .iter()
5237 .map(|tool| tool.name.clone())
5238 .collect::<Vec<_>>();
5239 assert_eq!(tools, vec!["grep"]);
5240}
5241
5242#[gpui::test]
5243async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5244 init_test(cx);
5245
5246 let fs = FakeFs::new(cx.executor());
5247 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5248 .await;
5249 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5250
5251 cx.update(|cx| {
5252 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5253 settings.tool_permissions.tools.insert(
5254 EditFileTool::NAME.into(),
5255 agent_settings::ToolRules {
5256 default: Some(settings::ToolPermissionMode::Allow),
5257 always_allow: vec![],
5258 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5259 always_confirm: vec![],
5260 invalid_patterns: vec![],
5261 },
5262 );
5263 agent_settings::AgentSettings::override_global(settings, cx);
5264 });
5265
5266 let context_server_registry =
5267 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5268 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5269 let templates = crate::Templates::new();
5270 let thread = cx.new(|cx| {
5271 crate::Thread::new(
5272 project.clone(),
5273 cx.new(|_cx| prompt_store::ProjectContext::default()),
5274 context_server_registry,
5275 templates.clone(),
5276 None,
5277 cx,
5278 )
5279 });
5280
5281 #[allow(clippy::arc_with_non_send_sync)]
5282 let tool = Arc::new(crate::EditFileTool::new(
5283 project.clone(),
5284 thread.downgrade(),
5285 language_registry,
5286 templates,
5287 ));
5288 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5289
5290 let task = cx.update(|cx| {
5291 tool.run(
5292 crate::EditFileToolInput {
5293 display_description: "Edit sensitive file".to_string(),
5294 path: "root/sensitive_config.txt".into(),
5295 mode: crate::EditFileMode::Edit,
5296 },
5297 event_stream,
5298 cx,
5299 )
5300 });
5301
5302 let result = task.await;
5303 assert!(result.is_err(), "expected edit to be blocked");
5304 assert!(
5305 result.unwrap_err().to_string().contains("blocked"),
5306 "error should mention the edit was blocked"
5307 );
5308}
5309
5310#[gpui::test]
5311async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5312 init_test(cx);
5313
5314 let fs = FakeFs::new(cx.executor());
5315 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5316 .await;
5317 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5318
5319 cx.update(|cx| {
5320 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5321 settings.tool_permissions.tools.insert(
5322 DeletePathTool::NAME.into(),
5323 agent_settings::ToolRules {
5324 default: Some(settings::ToolPermissionMode::Allow),
5325 always_allow: vec![],
5326 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5327 always_confirm: vec![],
5328 invalid_patterns: vec![],
5329 },
5330 );
5331 agent_settings::AgentSettings::override_global(settings, cx);
5332 });
5333
5334 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5335
5336 #[allow(clippy::arc_with_non_send_sync)]
5337 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5338 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5339
5340 let task = cx.update(|cx| {
5341 tool.run(
5342 crate::DeletePathToolInput {
5343 path: "root/important_data.txt".to_string(),
5344 },
5345 event_stream,
5346 cx,
5347 )
5348 });
5349
5350 let result = task.await;
5351 assert!(result.is_err(), "expected deletion to be blocked");
5352 assert!(
5353 result.unwrap_err().to_string().contains("blocked"),
5354 "error should mention the deletion was blocked"
5355 );
5356}
5357
5358#[gpui::test]
5359async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5360 init_test(cx);
5361
5362 let fs = FakeFs::new(cx.executor());
5363 fs.insert_tree(
5364 "/root",
5365 json!({
5366 "safe.txt": "content",
5367 "protected": {}
5368 }),
5369 )
5370 .await;
5371 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5372
5373 cx.update(|cx| {
5374 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5375 settings.tool_permissions.tools.insert(
5376 MovePathTool::NAME.into(),
5377 agent_settings::ToolRules {
5378 default: Some(settings::ToolPermissionMode::Allow),
5379 always_allow: vec![],
5380 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5381 always_confirm: vec![],
5382 invalid_patterns: vec![],
5383 },
5384 );
5385 agent_settings::AgentSettings::override_global(settings, cx);
5386 });
5387
5388 #[allow(clippy::arc_with_non_send_sync)]
5389 let tool = Arc::new(crate::MovePathTool::new(project));
5390 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5391
5392 let task = cx.update(|cx| {
5393 tool.run(
5394 crate::MovePathToolInput {
5395 source_path: "root/safe.txt".to_string(),
5396 destination_path: "root/protected/safe.txt".to_string(),
5397 },
5398 event_stream,
5399 cx,
5400 )
5401 });
5402
5403 let result = task.await;
5404 assert!(
5405 result.is_err(),
5406 "expected move to be blocked due to destination path"
5407 );
5408 assert!(
5409 result.unwrap_err().to_string().contains("blocked"),
5410 "error should mention the move was blocked"
5411 );
5412}
5413
5414#[gpui::test]
5415async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5416 init_test(cx);
5417
5418 let fs = FakeFs::new(cx.executor());
5419 fs.insert_tree(
5420 "/root",
5421 json!({
5422 "secret.txt": "secret content",
5423 "public": {}
5424 }),
5425 )
5426 .await;
5427 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5428
5429 cx.update(|cx| {
5430 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5431 settings.tool_permissions.tools.insert(
5432 MovePathTool::NAME.into(),
5433 agent_settings::ToolRules {
5434 default: Some(settings::ToolPermissionMode::Allow),
5435 always_allow: vec![],
5436 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5437 always_confirm: vec![],
5438 invalid_patterns: vec![],
5439 },
5440 );
5441 agent_settings::AgentSettings::override_global(settings, cx);
5442 });
5443
5444 #[allow(clippy::arc_with_non_send_sync)]
5445 let tool = Arc::new(crate::MovePathTool::new(project));
5446 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5447
5448 let task = cx.update(|cx| {
5449 tool.run(
5450 crate::MovePathToolInput {
5451 source_path: "root/secret.txt".to_string(),
5452 destination_path: "root/public/not_secret.txt".to_string(),
5453 },
5454 event_stream,
5455 cx,
5456 )
5457 });
5458
5459 let result = task.await;
5460 assert!(
5461 result.is_err(),
5462 "expected move to be blocked due to source path"
5463 );
5464 assert!(
5465 result.unwrap_err().to_string().contains("blocked"),
5466 "error should mention the move was blocked"
5467 );
5468}
5469
5470#[gpui::test]
5471async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5472 init_test(cx);
5473
5474 let fs = FakeFs::new(cx.executor());
5475 fs.insert_tree(
5476 "/root",
5477 json!({
5478 "confidential.txt": "confidential data",
5479 "dest": {}
5480 }),
5481 )
5482 .await;
5483 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5484
5485 cx.update(|cx| {
5486 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5487 settings.tool_permissions.tools.insert(
5488 CopyPathTool::NAME.into(),
5489 agent_settings::ToolRules {
5490 default: Some(settings::ToolPermissionMode::Allow),
5491 always_allow: vec![],
5492 always_deny: vec![
5493 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5494 ],
5495 always_confirm: vec![],
5496 invalid_patterns: vec![],
5497 },
5498 );
5499 agent_settings::AgentSettings::override_global(settings, cx);
5500 });
5501
5502 #[allow(clippy::arc_with_non_send_sync)]
5503 let tool = Arc::new(crate::CopyPathTool::new(project));
5504 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5505
5506 let task = cx.update(|cx| {
5507 tool.run(
5508 crate::CopyPathToolInput {
5509 source_path: "root/confidential.txt".to_string(),
5510 destination_path: "root/dest/copy.txt".to_string(),
5511 },
5512 event_stream,
5513 cx,
5514 )
5515 });
5516
5517 let result = task.await;
5518 assert!(result.is_err(), "expected copy to be blocked");
5519 assert!(
5520 result.unwrap_err().to_string().contains("blocked"),
5521 "error should mention the copy was blocked"
5522 );
5523}
5524
5525#[gpui::test]
5526async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5527 init_test(cx);
5528
5529 let fs = FakeFs::new(cx.executor());
5530 fs.insert_tree(
5531 "/root",
5532 json!({
5533 "normal.txt": "normal content",
5534 "readonly": {
5535 "config.txt": "readonly content"
5536 }
5537 }),
5538 )
5539 .await;
5540 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5541
5542 cx.update(|cx| {
5543 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5544 settings.tool_permissions.tools.insert(
5545 SaveFileTool::NAME.into(),
5546 agent_settings::ToolRules {
5547 default: Some(settings::ToolPermissionMode::Allow),
5548 always_allow: vec![],
5549 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5550 always_confirm: vec![],
5551 invalid_patterns: vec![],
5552 },
5553 );
5554 agent_settings::AgentSettings::override_global(settings, cx);
5555 });
5556
5557 #[allow(clippy::arc_with_non_send_sync)]
5558 let tool = Arc::new(crate::SaveFileTool::new(project));
5559 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5560
5561 let task = cx.update(|cx| {
5562 tool.run(
5563 crate::SaveFileToolInput {
5564 paths: vec![
5565 std::path::PathBuf::from("root/normal.txt"),
5566 std::path::PathBuf::from("root/readonly/config.txt"),
5567 ],
5568 },
5569 event_stream,
5570 cx,
5571 )
5572 });
5573
5574 let result = task.await;
5575 assert!(
5576 result.is_err(),
5577 "expected save to be blocked due to denied path"
5578 );
5579 assert!(
5580 result.unwrap_err().to_string().contains("blocked"),
5581 "error should mention the save was blocked"
5582 );
5583}
5584
5585#[gpui::test]
5586async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5587 init_test(cx);
5588
5589 let fs = FakeFs::new(cx.executor());
5590 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5591 .await;
5592 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5593
5594 cx.update(|cx| {
5595 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5596 settings.tool_permissions.tools.insert(
5597 SaveFileTool::NAME.into(),
5598 agent_settings::ToolRules {
5599 default: Some(settings::ToolPermissionMode::Allow),
5600 always_allow: vec![],
5601 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5602 always_confirm: vec![],
5603 invalid_patterns: vec![],
5604 },
5605 );
5606 agent_settings::AgentSettings::override_global(settings, cx);
5607 });
5608
5609 #[allow(clippy::arc_with_non_send_sync)]
5610 let tool = Arc::new(crate::SaveFileTool::new(project));
5611 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5612
5613 let task = cx.update(|cx| {
5614 tool.run(
5615 crate::SaveFileToolInput {
5616 paths: vec![std::path::PathBuf::from("root/config.secret")],
5617 },
5618 event_stream,
5619 cx,
5620 )
5621 });
5622
5623 let result = task.await;
5624 assert!(result.is_err(), "expected save to be blocked");
5625 assert!(
5626 result.unwrap_err().to_string().contains("blocked"),
5627 "error should mention the save was blocked"
5628 );
5629}
5630
5631#[gpui::test]
5632async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5633 init_test(cx);
5634
5635 cx.update(|cx| {
5636 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5637 settings.tool_permissions.tools.insert(
5638 WebSearchTool::NAME.into(),
5639 agent_settings::ToolRules {
5640 default: Some(settings::ToolPermissionMode::Allow),
5641 always_allow: vec![],
5642 always_deny: vec![
5643 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5644 ],
5645 always_confirm: vec![],
5646 invalid_patterns: vec![],
5647 },
5648 );
5649 agent_settings::AgentSettings::override_global(settings, cx);
5650 });
5651
5652 #[allow(clippy::arc_with_non_send_sync)]
5653 let tool = Arc::new(crate::WebSearchTool);
5654 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5655
5656 let input: crate::WebSearchToolInput =
5657 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5658
5659 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5660
5661 let result = task.await;
5662 assert!(result.is_err(), "expected search to be blocked");
5663 assert!(
5664 result.unwrap_err().to_string().contains("blocked"),
5665 "error should mention the search was blocked"
5666 );
5667}
5668
5669#[gpui::test]
5670async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5671 init_test(cx);
5672
5673 let fs = FakeFs::new(cx.executor());
5674 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5675 .await;
5676 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5677
5678 cx.update(|cx| {
5679 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5680 settings.tool_permissions.tools.insert(
5681 EditFileTool::NAME.into(),
5682 agent_settings::ToolRules {
5683 default: Some(settings::ToolPermissionMode::Confirm),
5684 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5685 always_deny: vec![],
5686 always_confirm: vec![],
5687 invalid_patterns: vec![],
5688 },
5689 );
5690 agent_settings::AgentSettings::override_global(settings, cx);
5691 });
5692
5693 let context_server_registry =
5694 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5695 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5696 let templates = crate::Templates::new();
5697 let thread = cx.new(|cx| {
5698 crate::Thread::new(
5699 project.clone(),
5700 cx.new(|_cx| prompt_store::ProjectContext::default()),
5701 context_server_registry,
5702 templates.clone(),
5703 None,
5704 cx,
5705 )
5706 });
5707
5708 #[allow(clippy::arc_with_non_send_sync)]
5709 let tool = Arc::new(crate::EditFileTool::new(
5710 project,
5711 thread.downgrade(),
5712 language_registry,
5713 templates,
5714 ));
5715 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5716
5717 let _task = cx.update(|cx| {
5718 tool.run(
5719 crate::EditFileToolInput {
5720 display_description: "Edit README".to_string(),
5721 path: "root/README.md".into(),
5722 mode: crate::EditFileMode::Edit,
5723 },
5724 event_stream,
5725 cx,
5726 )
5727 });
5728
5729 cx.run_until_parked();
5730
5731 let event = rx.try_next();
5732 assert!(
5733 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5734 "expected no authorization request for allowed .md file"
5735 );
5736}
5737
5738#[gpui::test]
5739async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5740 init_test(cx);
5741
5742 let fs = FakeFs::new(cx.executor());
5743 fs.insert_tree(
5744 "/root",
5745 json!({
5746 ".zed": {
5747 "settings.json": "{}"
5748 },
5749 "README.md": "# Hello"
5750 }),
5751 )
5752 .await;
5753 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5754
5755 cx.update(|cx| {
5756 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5757 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5758 agent_settings::AgentSettings::override_global(settings, cx);
5759 });
5760
5761 let context_server_registry =
5762 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5763 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5764 let templates = crate::Templates::new();
5765 let thread = cx.new(|cx| {
5766 crate::Thread::new(
5767 project.clone(),
5768 cx.new(|_cx| prompt_store::ProjectContext::default()),
5769 context_server_registry,
5770 templates.clone(),
5771 None,
5772 cx,
5773 )
5774 });
5775
5776 #[allow(clippy::arc_with_non_send_sync)]
5777 let tool = Arc::new(crate::EditFileTool::new(
5778 project,
5779 thread.downgrade(),
5780 language_registry,
5781 templates,
5782 ));
5783
5784 // Editing a file inside .zed/ should still prompt even with global default: allow,
5785 // because local settings paths are sensitive and require confirmation regardless.
5786 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5787 let _task = cx.update(|cx| {
5788 tool.run(
5789 crate::EditFileToolInput {
5790 display_description: "Edit local settings".to_string(),
5791 path: "root/.zed/settings.json".into(),
5792 mode: crate::EditFileMode::Edit,
5793 },
5794 event_stream,
5795 cx,
5796 )
5797 });
5798
5799 let _update = rx.expect_update_fields().await;
5800 let _auth = rx.expect_authorization().await;
5801}
5802
5803#[gpui::test]
5804async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5805 init_test(cx);
5806
5807 cx.update(|cx| {
5808 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5809 settings.tool_permissions.tools.insert(
5810 FetchTool::NAME.into(),
5811 agent_settings::ToolRules {
5812 default: Some(settings::ToolPermissionMode::Allow),
5813 always_allow: vec![],
5814 always_deny: vec![
5815 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5816 ],
5817 always_confirm: vec![],
5818 invalid_patterns: vec![],
5819 },
5820 );
5821 agent_settings::AgentSettings::override_global(settings, cx);
5822 });
5823
5824 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5825
5826 #[allow(clippy::arc_with_non_send_sync)]
5827 let tool = Arc::new(crate::FetchTool::new(http_client));
5828 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5829
5830 let input: crate::FetchToolInput =
5831 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5832
5833 let task = cx.update(|cx| tool.run(input, event_stream, cx));
5834
5835 let result = task.await;
5836 assert!(result.is_err(), "expected fetch to be blocked");
5837 assert!(
5838 result.unwrap_err().to_string().contains("blocked"),
5839 "error should mention the fetch was blocked"
5840 );
5841}
5842
5843#[gpui::test]
5844async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5845 init_test(cx);
5846
5847 cx.update(|cx| {
5848 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5849 settings.tool_permissions.tools.insert(
5850 FetchTool::NAME.into(),
5851 agent_settings::ToolRules {
5852 default: Some(settings::ToolPermissionMode::Confirm),
5853 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5854 always_deny: vec![],
5855 always_confirm: vec![],
5856 invalid_patterns: vec![],
5857 },
5858 );
5859 agent_settings::AgentSettings::override_global(settings, cx);
5860 });
5861
5862 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5863
5864 #[allow(clippy::arc_with_non_send_sync)]
5865 let tool = Arc::new(crate::FetchTool::new(http_client));
5866 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5867
5868 let input: crate::FetchToolInput =
5869 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5870
5871 let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5872
5873 cx.run_until_parked();
5874
5875 let event = rx.try_next();
5876 assert!(
5877 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5878 "expected no authorization request for allowed docs.rs URL"
5879 );
5880}
5881
5882#[gpui::test]
5883async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
5884 init_test(cx);
5885 always_allow_tools(cx);
5886
5887 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
5888 let fake_model = model.as_fake();
5889
5890 // Add a tool so we can simulate tool calls
5891 thread.update(cx, |thread, _cx| {
5892 thread.add_tool(EchoTool, None);
5893 });
5894
5895 // Start a turn by sending a message
5896 let mut events = thread
5897 .update(cx, |thread, cx| {
5898 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
5899 })
5900 .unwrap();
5901 cx.run_until_parked();
5902
5903 // Simulate the model making a tool call
5904 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5905 LanguageModelToolUse {
5906 id: "tool_1".into(),
5907 name: "echo".into(),
5908 raw_input: r#"{"text": "hello"}"#.into(),
5909 input: json!({"text": "hello"}),
5910 is_input_complete: true,
5911 thought_signature: None,
5912 },
5913 ));
5914 fake_model
5915 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
5916
5917 // Signal that a message is queued before ending the stream
5918 thread.update(cx, |thread, _cx| {
5919 thread.set_has_queued_message(true);
5920 });
5921
5922 // Now end the stream - tool will run, and the boundary check should see the queue
5923 fake_model.end_last_completion_stream();
5924
5925 // Collect all events until the turn stops
5926 let all_events = collect_events_until_stop(&mut events, cx).await;
5927
5928 // Verify we received the tool call event
5929 let tool_call_ids: Vec<_> = all_events
5930 .iter()
5931 .filter_map(|e| match e {
5932 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
5933 _ => None,
5934 })
5935 .collect();
5936 assert_eq!(
5937 tool_call_ids,
5938 vec!["tool_1"],
5939 "Should have received a tool call event for our echo tool"
5940 );
5941
5942 // The turn should have stopped with EndTurn
5943 let stop_reasons = stop_events(all_events);
5944 assert_eq!(
5945 stop_reasons,
5946 vec![acp::StopReason::EndTurn],
5947 "Turn should have ended after tool completion due to queued message"
5948 );
5949
5950 // Verify the queued message flag is still set
5951 thread.update(cx, |thread, _cx| {
5952 assert!(
5953 thread.has_queued_message(),
5954 "Should still have queued message flag set"
5955 );
5956 });
5957
5958 // Thread should be idle now
5959 thread.update(cx, |thread, _cx| {
5960 assert!(
5961 thread.is_turn_complete(),
5962 "Thread should not be running after turn ends"
5963 );
5964 });
5965}