1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
/// Scriptable stand-in for a terminal handle, used to observe how the terminal
/// tool drives a running command (kill requests, exit status, output).
struct FakeTerminalHandle {
    /// Set to `true` once `kill` has been called.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; toggled via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// Sender half of the exit signal; `signal_exit` takes it and fires it at most once.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task resolving to the terminal's exit status, cloned out by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned output returned by every `current_output` call.
    output: acp::TerminalOutputResponse,
    /// Identifier returned by `id`.
    id: acp::TerminalId,
}
76
77impl FakeTerminalHandle {
78 fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
136impl crate::TerminalHandle for FakeTerminalHandle {
137 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
138 Ok(self.id.clone())
139 }
140
141 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
142 Ok(self.output.clone())
143 }
144
145 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
146 Ok(self.wait_for_exit.clone())
147 }
148
149 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
150 self.killed.store(true, Ordering::SeqCst);
151 self.signal_exit();
152 Ok(())
153 }
154
155 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
156 Ok(self.stopped_by_user.load(Ordering::SeqCst))
157 }
158}
159
/// Stand-in subagent handle that answers every `send` with the output of a
/// pre-built shared task.
struct FakeSubagentHandle {
    /// Session id returned by `id`.
    session_id: acp::SessionId,
    /// Shared task whose result is returned (wrapped in `Ok`) by every `send`.
    send_task: Shared<Task<String>>,
}
164
165impl SubagentHandle for FakeSubagentHandle {
166 fn id(&self) -> acp::SessionId {
167 self.session_id.clone()
168 }
169
170 fn num_entries(&self, _cx: &App) -> usize {
171 unimplemented!()
172 }
173
174 fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
175 let task = self.send_task.clone();
176 cx.background_spawn(async move { Ok(task.await) })
177 }
178}
179
/// Thread environment that hands out pre-configured fake handles.
///
/// Both handles default to `None`; asking the environment for a handle that was
/// not configured panics.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Handle returned (cloned) by every `create_terminal` call.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Handle returned (cloned) by every `create_subagent` call.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
185
186impl FakeThreadEnvironment {
187 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
188 Self {
189 terminal_handle: Some(terminal_handle.into()),
190 ..self
191 }
192 }
193}
194
195impl crate::ThreadEnvironment for FakeThreadEnvironment {
196 fn create_terminal(
197 &self,
198 _command: String,
199 _cwd: Option<std::path::PathBuf>,
200 _output_byte_limit: Option<u64>,
201 _cx: &mut AsyncApp,
202 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
203 let handle = self
204 .terminal_handle
205 .clone()
206 .expect("Terminal handle not available on FakeThreadEnvironment");
207 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
208 }
209
210 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
211 Ok(self
212 .subagent_handle
213 .clone()
214 .expect("Subagent handle not available on FakeThreadEnvironment")
215 as Rc<dyn SubagentHandle>)
216 }
217}
218
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
///
/// Every handle it creates is also recorded so tests can inspect them afterwards.
struct MultiTerminalEnvironment {
    /// All handles created so far, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
223
224impl MultiTerminalEnvironment {
225 fn new() -> Self {
226 Self {
227 handles: std::cell::RefCell::new(Vec::new()),
228 }
229 }
230
231 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
232 self.handles.borrow().clone()
233 }
234}
235
236impl crate::ThreadEnvironment for MultiTerminalEnvironment {
237 fn create_terminal(
238 &self,
239 _command: String,
240 _cwd: Option<std::path::PathBuf>,
241 _output_byte_limit: Option<u64>,
242 cx: &mut AsyncApp,
243 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
244 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
245 self.handles.borrow_mut().push(handle.clone());
246 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
247 }
248
249 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
250 unimplemented!()
251 }
252}
253
254fn always_allow_tools(cx: &mut TestAppContext) {
255 cx.update(|cx| {
256 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
257 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
258 agent_settings::AgentSettings::override_global(settings, cx);
259 });
260}
261
/// Sends one user message, streams "Hello" from the fake model, and checks that
/// it becomes the assistant's reply and that the turn stops with `EndTurn`.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    // Let the thread issue its completion request before the fake model responds.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message().unwrap().role(),
            Role::Assistant
        );
        assert_eq!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .to_markdown(),
            "Hello\n"
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
294
/// Runs the terminal tool with a 5ms timeout against a terminal that never exits
/// on its own. Checks that the handle gets killed and that the tool result
/// carries the terminal's (partial) output.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so the handle can be inspected after the tool finishes.
    let handle = environment.terminal_handle.clone().unwrap();

    // NOTE(review): the environment is `Rc`-based (not thread-safe) but the tool is put in an
    // `Arc`, hence the lint allowance — presumably because `run` needs an `Arc` receiver; confirm.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    let task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            }),
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());

    // Poll the task without blocking, driving the executor and a 1ms timer between polls.
    // The wall-clock deadline only guards against the test hanging forever.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
360
/// Runs the terminal tool with no timeout against a terminal that never exits and
/// checks that, after a short wait, the handle has not been killed.
/// The test is currently marked `#[ignore]`; the reason is not recorded here.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    // Keep our own reference so the handle can be inspected later.
    let handle = environment.terminal_handle.clone().unwrap();

    // NOTE(review): same `Rc` environment inside an `Arc` as the timeout test; see the lint allowance.
    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Bound to a name (not `_`) so the task lives for the rest of the test —
    // presumably dropping it would cancel the run; confirm.
    let _task = cx.update(|cx| {
        tool.run(
            ToolInput::resolved(crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            }),
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    cx.background_executor
        .timer(Duration::from_millis(25))
        .await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
410
/// Streams a thinking event followed by text and checks that the assistant
/// message renders as a `<think>` block followed by the answer.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    // Let the thread issue its completion request before the fake model responds.
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message().unwrap().role(),
            Role::Assistant
        );
        assert_eq!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .to_markdown(),
            indoc! {"
                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
459
/// With a tool registered and a custom shell in the project context, checks that
/// the first request message is a system prompt mentioning the shell and
/// containing the "## Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
504
/// With no tools registered, checks that the system prompt contains neither the
/// "## Tool Use" nor the "## Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
540
/// Checks prompt-cache marking across turns. Only the newest user-side message
/// (a user message, or the latest tool result after a tool call) carries
/// `cache: true`; every earlier message is sent with `cache: false`.
/// Each comparison slices from index 1 to skip the leading system message.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (sent after the tool ran) is now the last pending completion.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
686
/// End-to-end test against a real model (`TestModel::Sonnet4`), ignored unless the
/// `e2e` feature is enabled. It exercises one tool call that is likely to finish
/// before streaming stops (echo) and one likely to finish after (delay). Both turns
/// must end with `EndTurn`, and the final agent message must mention "Ding".
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::NAME);
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_received_or_pending_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
746
/// End-to-end test against a real model, ignored unless the `e2e` feature is enabled.
/// While the `word_list` tool input streams in, checks that the last content of the
/// agent message is always that tool use. At least one snapshot must be observed
/// while the input is still partial, and non-pending snapshots must already include
/// the "a" and "g" keys.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_received_or_pending_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // A pending call whose input is incomplete and lacks the
                        // "g" key counts as a partially streamed tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
798
/// Walks through the permission flow for a tool that requires authorization:
/// 1. Two calls in one turn: the first is allowed once and the second denied; the
///    results are "Allowed" and the user-denied error.
/// 2. A third call is answered with "always allow" for this tool.
/// 3. A fourth call must then run without emitting another authorization request.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // Both results are reported together in the last message of the follow-up request.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
937
/// When the model calls a tool that does not exist, the thread still emits a
/// pending tool call titled with the requested name, then marks it `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
967
968async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
969 let event = events
970 .next()
971 .await
972 .expect("no tool call authorization event received")
973 .unwrap();
974 match event {
975 ThreadEvent::ToolCall(tool_call) => tool_call,
976 event => {
977 panic!("Unexpected event {event:?}");
978 }
979 }
980}
981
982async fn expect_tool_call_update_fields(
983 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
984) -> acp::ToolCallUpdate {
985 let event = events
986 .next()
987 .await
988 .expect("no tool call authorization event received")
989 .unwrap();
990 match event {
991 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
992 event => {
993 panic!("Unexpected event {event:?}");
994 }
995 }
996}
997
998async fn next_tool_call_authorization(
999 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1000) -> ToolCallAuthorization {
1001 loop {
1002 let event = events
1003 .next()
1004 .await
1005 .expect("no tool call authorization event received")
1006 .unwrap();
1007 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1008 let permission_kinds = tool_call_authorization
1009 .options
1010 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1011 .map(|option| option.kind);
1012 let allow_once = tool_call_authorization
1013 .options
1014 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1015 .map(|option| option.kind);
1016
1017 assert_eq!(
1018 permission_kinds,
1019 Some(acp::PermissionOptionKind::AllowAlways)
1020 );
1021 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1022 return tool_call_authorization;
1023 }
1024 }
1025}
1026
/// A multi-word terminal command offers three choices: always for the tool, always
/// for the command's first two words, and once.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1048
/// When the second token is a flag (`ls -la`), the command pattern uses only the
/// first word.
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()]);

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1068
/// A single-word command produces a command pattern for that word.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()]);

    let PermissionOptions::Dropdown(choices) = context.build_permission_options() else {
        panic!("Expected dropdown permission options");
    };

    assert_eq!(choices.len(), 3);
    let labels = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect::<Vec<&str>>();
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1088
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    // Editing a file offers a tool-wide option plus one scoped to the file's
    // parent directory.
    let context =
        ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(labels.contains(&expected));
    }
}
1106
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    // Fetching a URL offers a tool-wide option plus one scoped to its domain.
    let context =
        ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(labels.contains(&expected));
    }
}
1124
#[test]
fn test_permission_options_without_pattern() {
    // A script-path command yields no command pattern, so only the tool-wide
    // and one-time choices remain.
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 2);
    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(labels.contains(&expected));
    }
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1146
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    // Symlink-target authorization uses the flat layout and offers exactly
    // two options: allow once and reject once.
    let context =
        ToolPermissionContext::symlink_target(EditFileTool::NAME, vec!["/outside/file.txt".into()]);
    let options = match context.build_permission_options() {
        PermissionOptions::Flat(options) => options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };

    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };

    assert_eq!(options.len(), 2);
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1167
#[test]
fn test_permission_option_ids_for_terminal() {
    // Each dropdown choice carries an allow id and a deny id. Check that the
    // tool-wide, one-time, and pattern-scoped ids all appear on both sides.
    fn has_id(ids: &[String], expected: &str) -> bool {
        ids.iter().any(|id| id == expected)
    }

    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    assert!(
        allow_ids
            .iter()
            .any(|id| id.starts_with("always_allow_pattern:terminal\n")),
        "Missing allow pattern option"
    );

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    assert!(
        deny_ids
            .iter()
            .any(|id| id.starts_with("always_deny_pattern:terminal\n")),
        "Missing deny pattern option"
    );
}
1207
1208#[gpui::test]
1209#[cfg_attr(not(feature = "e2e"), ignore)]
1210async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1211 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1212
1213 // Test concurrent tool calls with different delay times
1214 let events = thread
1215 .update(cx, |thread, cx| {
1216 thread.add_tool(DelayTool);
1217 thread.send(
1218 UserMessageId::new(),
1219 [
1220 "Call the delay tool twice in the same message.",
1221 "Once with 100ms. Once with 300ms.",
1222 "When both timers are complete, describe the outputs.",
1223 ],
1224 cx,
1225 )
1226 })
1227 .unwrap()
1228 .collect()
1229 .await;
1230
1231 let stop_reasons = stop_events(events);
1232 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1233
1234 thread.update(cx, |thread, _cx| {
1235 let last_message = thread.last_received_or_pending_message().unwrap();
1236 let agent_message = last_message.as_agent_message().unwrap();
1237 let text = agent_message
1238 .content
1239 .iter()
1240 .filter_map(|content| {
1241 if let AgentMessageContent::Text(text) = content {
1242 Some(text.as_str())
1243 } else {
1244 None
1245 }
1246 })
1247 .collect::<String>();
1248
1249 assert!(text.contains("Ding"));
1250 });
1251}
1252
1253#[gpui::test]
1254async fn test_profiles(cx: &mut TestAppContext) {
1255 let ThreadTest {
1256 model, thread, fs, ..
1257 } = setup(cx, TestModel::Fake).await;
1258 let fake_model = model.as_fake();
1259
1260 thread.update(cx, |thread, _cx| {
1261 thread.add_tool(DelayTool);
1262 thread.add_tool(EchoTool);
1263 thread.add_tool(InfiniteTool);
1264 });
1265
1266 // Override profiles and wait for settings to be loaded.
1267 fs.insert_file(
1268 paths::settings_file(),
1269 json!({
1270 "agent": {
1271 "profiles": {
1272 "test-1": {
1273 "name": "Test Profile 1",
1274 "tools": {
1275 EchoTool::NAME: true,
1276 DelayTool::NAME: true,
1277 }
1278 },
1279 "test-2": {
1280 "name": "Test Profile 2",
1281 "tools": {
1282 InfiniteTool::NAME: true,
1283 }
1284 }
1285 }
1286 }
1287 })
1288 .to_string()
1289 .into_bytes(),
1290 )
1291 .await;
1292 cx.run_until_parked();
1293
1294 // Test that test-1 profile (default) has echo and delay tools
1295 thread
1296 .update(cx, |thread, cx| {
1297 thread.set_profile(AgentProfileId("test-1".into()), cx);
1298 thread.send(UserMessageId::new(), ["test"], cx)
1299 })
1300 .unwrap();
1301 cx.run_until_parked();
1302
1303 let mut pending_completions = fake_model.pending_completions();
1304 assert_eq!(pending_completions.len(), 1);
1305 let completion = pending_completions.pop().unwrap();
1306 let tool_names: Vec<String> = completion
1307 .tools
1308 .iter()
1309 .map(|tool| tool.name.clone())
1310 .collect();
1311 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1312 fake_model.end_last_completion_stream();
1313
1314 // Switch to test-2 profile, and verify that it has only the infinite tool.
1315 thread
1316 .update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test-2".into()), cx);
1318 thread.send(UserMessageId::new(), ["test2"], cx)
1319 })
1320 .unwrap();
1321 cx.run_until_parked();
1322 let mut pending_completions = fake_model.pending_completions();
1323 assert_eq!(pending_completions.len(), 1);
1324 let completion = pending_completions.pop().unwrap();
1325 let tool_names: Vec<String> = completion
1326 .tools
1327 .iter()
1328 .map(|tool| tool.name.clone())
1329 .collect();
1330 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1331}
1332
1333#[gpui::test]
1334async fn test_mcp_tools(cx: &mut TestAppContext) {
1335 let ThreadTest {
1336 model,
1337 thread,
1338 context_server_store,
1339 fs,
1340 ..
1341 } = setup(cx, TestModel::Fake).await;
1342 let fake_model = model.as_fake();
1343
1344 // Override profiles and wait for settings to be loaded.
1345 fs.insert_file(
1346 paths::settings_file(),
1347 json!({
1348 "agent": {
1349 "tool_permissions": { "default": "allow" },
1350 "profiles": {
1351 "test": {
1352 "name": "Test Profile",
1353 "enable_all_context_servers": true,
1354 "tools": {
1355 EchoTool::NAME: true,
1356 }
1357 },
1358 }
1359 }
1360 })
1361 .to_string()
1362 .into_bytes(),
1363 )
1364 .await;
1365 cx.run_until_parked();
1366 thread.update(cx, |thread, cx| {
1367 thread.set_profile(AgentProfileId("test".into()), cx)
1368 });
1369
1370 let mut mcp_tool_calls = setup_context_server(
1371 "test_server",
1372 vec![context_server::types::Tool {
1373 name: "echo".into(),
1374 description: None,
1375 input_schema: serde_json::to_value(EchoTool::input_schema(
1376 LanguageModelToolSchemaFormat::JsonSchema,
1377 ))
1378 .unwrap(),
1379 output_schema: None,
1380 annotations: None,
1381 }],
1382 &context_server_store,
1383 cx,
1384 );
1385
1386 let events = thread.update(cx, |thread, cx| {
1387 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1388 });
1389 cx.run_until_parked();
1390
1391 // Simulate the model calling the MCP tool.
1392 let completion = fake_model.pending_completions().pop().unwrap();
1393 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1394 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1395 LanguageModelToolUse {
1396 id: "tool_1".into(),
1397 name: "echo".into(),
1398 raw_input: json!({"text": "test"}).to_string(),
1399 input: json!({"text": "test"}),
1400 is_input_complete: true,
1401 thought_signature: None,
1402 },
1403 ));
1404 fake_model.end_last_completion_stream();
1405 cx.run_until_parked();
1406
1407 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1408 assert_eq!(tool_call_params.name, "echo");
1409 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1410 tool_call_response
1411 .send(context_server::types::CallToolResponse {
1412 content: vec![context_server::types::ToolResponseContent::Text {
1413 text: "test".into(),
1414 }],
1415 is_error: None,
1416 meta: None,
1417 structured_content: None,
1418 })
1419 .unwrap();
1420 cx.run_until_parked();
1421
1422 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1423 fake_model.send_last_completion_stream_text_chunk("Done!");
1424 fake_model.end_last_completion_stream();
1425 events.collect::<Vec<_>>().await;
1426
1427 // Send again after adding the echo tool, ensuring the name collision is resolved.
1428 let events = thread.update(cx, |thread, cx| {
1429 thread.add_tool(EchoTool);
1430 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1431 });
1432 cx.run_until_parked();
1433 let completion = fake_model.pending_completions().pop().unwrap();
1434 assert_eq!(
1435 tool_names_for_completion(&completion),
1436 vec!["echo", "test_server_echo"]
1437 );
1438 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1439 LanguageModelToolUse {
1440 id: "tool_2".into(),
1441 name: "test_server_echo".into(),
1442 raw_input: json!({"text": "mcp"}).to_string(),
1443 input: json!({"text": "mcp"}),
1444 is_input_complete: true,
1445 thought_signature: None,
1446 },
1447 ));
1448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1449 LanguageModelToolUse {
1450 id: "tool_3".into(),
1451 name: "echo".into(),
1452 raw_input: json!({"text": "native"}).to_string(),
1453 input: json!({"text": "native"}),
1454 is_input_complete: true,
1455 thought_signature: None,
1456 },
1457 ));
1458 fake_model.end_last_completion_stream();
1459 cx.run_until_parked();
1460
1461 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1462 assert_eq!(tool_call_params.name, "echo");
1463 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1464 tool_call_response
1465 .send(context_server::types::CallToolResponse {
1466 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1467 is_error: None,
1468 meta: None,
1469 structured_content: None,
1470 })
1471 .unwrap();
1472 cx.run_until_parked();
1473
1474 // Ensure the tool results were inserted with the correct names.
1475 let completion = fake_model.pending_completions().pop().unwrap();
1476 assert_eq!(
1477 completion.messages.last().unwrap().content,
1478 vec![
1479 MessageContent::ToolResult(LanguageModelToolResult {
1480 tool_use_id: "tool_3".into(),
1481 tool_name: "echo".into(),
1482 is_error: false,
1483 content: "native".into(),
1484 output: Some("native".into()),
1485 },),
1486 MessageContent::ToolResult(LanguageModelToolResult {
1487 tool_use_id: "tool_2".into(),
1488 tool_name: "test_server_echo".into(),
1489 is_error: false,
1490 content: "mcp".into(),
1491 output: Some("mcp".into()),
1492 },),
1493 ]
1494 );
1495 fake_model.end_last_completion_stream();
1496 events.collect::<Vec<_>>().await;
1497}
1498
1499#[gpui::test]
1500async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
1501 let ThreadTest {
1502 model,
1503 thread,
1504 context_server_store,
1505 fs,
1506 ..
1507 } = setup(cx, TestModel::Fake).await;
1508 let fake_model = model.as_fake();
1509
1510 // Setup settings to allow MCP tools
1511 fs.insert_file(
1512 paths::settings_file(),
1513 json!({
1514 "agent": {
1515 "always_allow_tool_actions": true,
1516 "profiles": {
1517 "test": {
1518 "name": "Test Profile",
1519 "enable_all_context_servers": true,
1520 "tools": {}
1521 },
1522 }
1523 }
1524 })
1525 .to_string()
1526 .into_bytes(),
1527 )
1528 .await;
1529 cx.run_until_parked();
1530 thread.update(cx, |thread, cx| {
1531 thread.set_profile(AgentProfileId("test".into()), cx)
1532 });
1533
1534 // Setup a context server with a tool
1535 let mut mcp_tool_calls = setup_context_server(
1536 "github_server",
1537 vec![context_server::types::Tool {
1538 name: "issue_read".into(),
1539 description: Some("Read a GitHub issue".into()),
1540 input_schema: json!({
1541 "type": "object",
1542 "properties": {
1543 "issue_url": { "type": "string" }
1544 }
1545 }),
1546 output_schema: None,
1547 annotations: None,
1548 }],
1549 &context_server_store,
1550 cx,
1551 );
1552
1553 // Send a message and have the model call the MCP tool
1554 let events = thread.update(cx, |thread, cx| {
1555 thread
1556 .send(UserMessageId::new(), ["Read issue #47404"], cx)
1557 .unwrap()
1558 });
1559 cx.run_until_parked();
1560
1561 // Verify the MCP tool is available to the model
1562 let completion = fake_model.pending_completions().pop().unwrap();
1563 assert_eq!(
1564 tool_names_for_completion(&completion),
1565 vec!["issue_read"],
1566 "MCP tool should be available"
1567 );
1568
1569 // Simulate the model calling the MCP tool
1570 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1571 LanguageModelToolUse {
1572 id: "tool_1".into(),
1573 name: "issue_read".into(),
1574 raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
1575 .to_string(),
1576 input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
1577 is_input_complete: true,
1578 thought_signature: None,
1579 },
1580 ));
1581 fake_model.end_last_completion_stream();
1582 cx.run_until_parked();
1583
1584 // The MCP server receives the tool call and responds with content
1585 let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
1586 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1587 assert_eq!(tool_call_params.name, "issue_read");
1588 tool_call_response
1589 .send(context_server::types::CallToolResponse {
1590 content: vec![context_server::types::ToolResponseContent::Text {
1591 text: expected_tool_output.into(),
1592 }],
1593 is_error: None,
1594 meta: None,
1595 structured_content: None,
1596 })
1597 .unwrap();
1598 cx.run_until_parked();
1599
1600 // After tool completes, the model continues with a new completion request
1601 // that includes the tool results. We need to respond to this.
1602 let _completion = fake_model.pending_completions().pop().unwrap();
1603 fake_model.send_last_completion_stream_text_chunk("I found the issue!");
1604 fake_model
1605 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1606 fake_model.end_last_completion_stream();
1607 events.collect::<Vec<_>>().await;
1608
1609 // Verify the tool result is stored in the thread by checking the markdown output.
1610 // The tool result is in the first assistant message (not the last one, which is
1611 // the model's response after the tool completed).
1612 thread.update(cx, |thread, _cx| {
1613 let markdown = thread.to_markdown();
1614 assert!(
1615 markdown.contains("**Tool Result**: issue_read"),
1616 "Thread should contain tool result header"
1617 );
1618 assert!(
1619 markdown.contains(expected_tool_output),
1620 "Thread should contain tool output: {}",
1621 expected_tool_output
1622 );
1623 });
1624
1625 // Simulate app restart: disconnect the MCP server.
1626 // After restart, the MCP server won't be connected yet when the thread is replayed.
1627 context_server_store.update(cx, |store, cx| {
1628 let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
1629 });
1630 cx.run_until_parked();
1631
1632 // Replay the thread (this is what happens when loading a saved thread)
1633 let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));
1634
1635 let mut found_tool_call = None;
1636 let mut found_tool_call_update_with_output = None;
1637
1638 while let Some(event) = replay_events.next().await {
1639 let event = event.unwrap();
1640 match &event {
1641 ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
1642 found_tool_call = Some(tc.clone());
1643 }
1644 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
1645 if update.tool_call_id.to_string() == "tool_1" =>
1646 {
1647 if update.fields.raw_output.is_some() {
1648 found_tool_call_update_with_output = Some(update.clone());
1649 }
1650 }
1651 _ => {}
1652 }
1653 }
1654
1655 // The tool call should be found
1656 assert!(
1657 found_tool_call.is_some(),
1658 "Tool call should be emitted during replay"
1659 );
1660
1661 assert!(
1662 found_tool_call_update_with_output.is_some(),
1663 "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
1664 );
1665
1666 let update = found_tool_call_update_with_output.unwrap();
1667 assert_eq!(
1668 update.fields.raw_output,
1669 Some(expected_tool_output.into()),
1670 "raw_output should contain the saved tool result"
1671 );
1672
1673 // Also verify the status is correct (completed, not failed)
1674 assert_eq!(
1675 update.fields.status,
1676 Some(acp::ToolCallStatus::Completed),
1677 "Tool call status should reflect the original completion status"
1678 );
1679}
1680
1681#[gpui::test]
1682async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1683 let ThreadTest {
1684 model,
1685 thread,
1686 context_server_store,
1687 fs,
1688 ..
1689 } = setup(cx, TestModel::Fake).await;
1690 let fake_model = model.as_fake();
1691
1692 // Set up a profile with all tools enabled
1693 fs.insert_file(
1694 paths::settings_file(),
1695 json!({
1696 "agent": {
1697 "profiles": {
1698 "test": {
1699 "name": "Test Profile",
1700 "enable_all_context_servers": true,
1701 "tools": {
1702 EchoTool::NAME: true,
1703 DelayTool::NAME: true,
1704 WordListTool::NAME: true,
1705 ToolRequiringPermission::NAME: true,
1706 InfiniteTool::NAME: true,
1707 }
1708 },
1709 }
1710 }
1711 })
1712 .to_string()
1713 .into_bytes(),
1714 )
1715 .await;
1716 cx.run_until_parked();
1717
1718 thread.update(cx, |thread, cx| {
1719 thread.set_profile(AgentProfileId("test".into()), cx);
1720 thread.add_tool(EchoTool);
1721 thread.add_tool(DelayTool);
1722 thread.add_tool(WordListTool);
1723 thread.add_tool(ToolRequiringPermission);
1724 thread.add_tool(InfiniteTool);
1725 });
1726
1727 // Set up multiple context servers with some overlapping tool names
1728 let _server1_calls = setup_context_server(
1729 "xxx",
1730 vec![
1731 context_server::types::Tool {
1732 name: "echo".into(), // Conflicts with native EchoTool
1733 description: None,
1734 input_schema: serde_json::to_value(EchoTool::input_schema(
1735 LanguageModelToolSchemaFormat::JsonSchema,
1736 ))
1737 .unwrap(),
1738 output_schema: None,
1739 annotations: None,
1740 },
1741 context_server::types::Tool {
1742 name: "unique_tool_1".into(),
1743 description: None,
1744 input_schema: json!({"type": "object", "properties": {}}),
1745 output_schema: None,
1746 annotations: None,
1747 },
1748 ],
1749 &context_server_store,
1750 cx,
1751 );
1752
1753 let _server2_calls = setup_context_server(
1754 "yyy",
1755 vec![
1756 context_server::types::Tool {
1757 name: "echo".into(), // Also conflicts with native EchoTool
1758 description: None,
1759 input_schema: serde_json::to_value(EchoTool::input_schema(
1760 LanguageModelToolSchemaFormat::JsonSchema,
1761 ))
1762 .unwrap(),
1763 output_schema: None,
1764 annotations: None,
1765 },
1766 context_server::types::Tool {
1767 name: "unique_tool_2".into(),
1768 description: None,
1769 input_schema: json!({"type": "object", "properties": {}}),
1770 output_schema: None,
1771 annotations: None,
1772 },
1773 context_server::types::Tool {
1774 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1775 description: None,
1776 input_schema: json!({"type": "object", "properties": {}}),
1777 output_schema: None,
1778 annotations: None,
1779 },
1780 context_server::types::Tool {
1781 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1782 description: None,
1783 input_schema: json!({"type": "object", "properties": {}}),
1784 output_schema: None,
1785 annotations: None,
1786 },
1787 ],
1788 &context_server_store,
1789 cx,
1790 );
1791 let _server3_calls = setup_context_server(
1792 "zzz",
1793 vec![
1794 context_server::types::Tool {
1795 name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1796 description: None,
1797 input_schema: json!({"type": "object", "properties": {}}),
1798 output_schema: None,
1799 annotations: None,
1800 },
1801 context_server::types::Tool {
1802 name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1803 description: None,
1804 input_schema: json!({"type": "object", "properties": {}}),
1805 output_schema: None,
1806 annotations: None,
1807 },
1808 context_server::types::Tool {
1809 name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1810 description: None,
1811 input_schema: json!({"type": "object", "properties": {}}),
1812 output_schema: None,
1813 annotations: None,
1814 },
1815 ],
1816 &context_server_store,
1817 cx,
1818 );
1819
1820 // Server with spaces in name - tests snake_case conversion for API compatibility
1821 let _server4_calls = setup_context_server(
1822 "Azure DevOps",
1823 vec![context_server::types::Tool {
1824 name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
1825 description: None,
1826 input_schema: serde_json::to_value(EchoTool::input_schema(
1827 LanguageModelToolSchemaFormat::JsonSchema,
1828 ))
1829 .unwrap(),
1830 output_schema: None,
1831 annotations: None,
1832 }],
1833 &context_server_store,
1834 cx,
1835 );
1836
1837 thread
1838 .update(cx, |thread, cx| {
1839 thread.send(UserMessageId::new(), ["Go"], cx)
1840 })
1841 .unwrap();
1842 cx.run_until_parked();
1843 let completion = fake_model.pending_completions().pop().unwrap();
1844 assert_eq!(
1845 tool_names_for_completion(&completion),
1846 vec![
1847 "azure_dev_ops_echo",
1848 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1849 "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1850 "delay",
1851 "echo",
1852 "infinite",
1853 "tool_requiring_permission",
1854 "unique_tool_1",
1855 "unique_tool_2",
1856 "word_list",
1857 "xxx_echo",
1858 "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1859 "yyy_echo",
1860 "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1861 ]
1862 );
1863}
1864
1865#[gpui::test]
1866#[cfg_attr(not(feature = "e2e"), ignore)]
1867async fn test_cancellation(cx: &mut TestAppContext) {
1868 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1869
1870 let mut events = thread
1871 .update(cx, |thread, cx| {
1872 thread.add_tool(InfiniteTool);
1873 thread.add_tool(EchoTool);
1874 thread.send(
1875 UserMessageId::new(),
1876 ["Call the echo tool, then call the infinite tool, then explain their output"],
1877 cx,
1878 )
1879 })
1880 .unwrap();
1881
1882 // Wait until both tools are called.
1883 let mut expected_tools = vec!["Echo", "Infinite Tool"];
1884 let mut echo_id = None;
1885 let mut echo_completed = false;
1886 while let Some(event) = events.next().await {
1887 match event.unwrap() {
1888 ThreadEvent::ToolCall(tool_call) => {
1889 assert_eq!(tool_call.title, expected_tools.remove(0));
1890 if tool_call.title == "Echo" {
1891 echo_id = Some(tool_call.tool_call_id);
1892 }
1893 }
1894 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1895 acp::ToolCallUpdate {
1896 tool_call_id,
1897 fields:
1898 acp::ToolCallUpdateFields {
1899 status: Some(acp::ToolCallStatus::Completed),
1900 ..
1901 },
1902 ..
1903 },
1904 )) if Some(&tool_call_id) == echo_id.as_ref() => {
1905 echo_completed = true;
1906 }
1907 _ => {}
1908 }
1909
1910 if expected_tools.is_empty() && echo_completed {
1911 break;
1912 }
1913 }
1914
1915 // Cancel the current send and ensure that the event stream is closed, even
1916 // if one of the tools is still running.
1917 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1918 let events = events.collect::<Vec<_>>().await;
1919 let last_event = events.last();
1920 assert!(
1921 matches!(
1922 last_event,
1923 Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1924 ),
1925 "unexpected event {last_event:?}"
1926 );
1927
1928 // Ensure we can still send a new message after cancellation.
1929 let events = thread
1930 .update(cx, |thread, cx| {
1931 thread.send(
1932 UserMessageId::new(),
1933 ["Testing: reply with 'Hello' then stop."],
1934 cx,
1935 )
1936 })
1937 .unwrap()
1938 .collect::<Vec<_>>()
1939 .await;
1940 thread.update(cx, |thread, _cx| {
1941 let message = thread.last_received_or_pending_message().unwrap();
1942 let agent_message = message.as_agent_message().unwrap();
1943 assert_eq!(
1944 agent_message.content,
1945 vec![AgentMessageContent::Text("Hello".to_string())]
1946 );
1947 });
1948 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1949}
1950
1951#[gpui::test]
1952async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1953 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1954 always_allow_tools(cx);
1955 let fake_model = model.as_fake();
1956
1957 let environment = Rc::new(cx.update(|cx| {
1958 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1959 }));
1960 let handle = environment.terminal_handle.clone().unwrap();
1961
1962 let mut events = thread
1963 .update(cx, |thread, cx| {
1964 thread.add_tool(crate::TerminalTool::new(
1965 thread.project().clone(),
1966 environment,
1967 ));
1968 thread.send(UserMessageId::new(), ["run a command"], cx)
1969 })
1970 .unwrap();
1971
1972 cx.run_until_parked();
1973
1974 // Simulate the model calling the terminal tool
1975 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1976 LanguageModelToolUse {
1977 id: "terminal_tool_1".into(),
1978 name: TerminalTool::NAME.into(),
1979 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1980 input: json!({"command": "sleep 1000", "cd": "."}),
1981 is_input_complete: true,
1982 thought_signature: None,
1983 },
1984 ));
1985 fake_model.end_last_completion_stream();
1986
1987 // Wait for the terminal tool to start running
1988 wait_for_terminal_tool_started(&mut events, cx).await;
1989
1990 // Cancel the thread while the terminal is running
1991 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1992
1993 // Collect remaining events, driving the executor to let cancellation complete
1994 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1995
1996 // Verify the terminal was killed
1997 assert!(
1998 handle.was_killed(),
1999 "expected terminal handle to be killed on cancellation"
2000 );
2001
2002 // Verify we got a cancellation stop event
2003 assert_eq!(
2004 stop_events(remaining_events),
2005 vec![acp::StopReason::Cancelled],
2006 );
2007
2008 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
2009 thread.update(cx, |thread, _cx| {
2010 let message = thread.last_received_or_pending_message().unwrap();
2011 let agent_message = message.as_agent_message().unwrap();
2012
2013 let tool_use = agent_message
2014 .content
2015 .iter()
2016 .find_map(|content| match content {
2017 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2018 _ => None,
2019 })
2020 .expect("expected tool use in agent message");
2021
2022 let tool_result = agent_message
2023 .tool_results
2024 .get(&tool_use.id)
2025 .expect("expected tool result");
2026
2027 let result_text = match &tool_result.content {
2028 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2029 _ => panic!("expected text content in tool result"),
2030 };
2031
2032 // "partial output" comes from FakeTerminalHandle's output field
2033 assert!(
2034 result_text.contains("partial output"),
2035 "expected tool result to contain terminal output, got: {result_text}"
2036 );
2037 // Match the actual format from process_content in terminal_tool.rs
2038 assert!(
2039 result_text.contains("The user stopped this command"),
2040 "expected tool result to indicate user stopped, got: {result_text}"
2041 );
2042 });
2043
2044 // Verify we can send a new message after cancellation
2045 verify_thread_recovery(&thread, &fake_model, cx).await;
2046}
2047
2048#[gpui::test]
2049async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2050 // This test verifies that tools which properly handle cancellation via
2051 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2052 // to cancellation and report that they were cancelled.
2053 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2054 always_allow_tools(cx);
2055 let fake_model = model.as_fake();
2056
2057 let (tool, was_cancelled) = CancellationAwareTool::new();
2058
2059 let mut events = thread
2060 .update(cx, |thread, cx| {
2061 thread.add_tool(tool);
2062 thread.send(
2063 UserMessageId::new(),
2064 ["call the cancellation aware tool"],
2065 cx,
2066 )
2067 })
2068 .unwrap();
2069
2070 cx.run_until_parked();
2071
2072 // Simulate the model calling the cancellation-aware tool
2073 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2074 LanguageModelToolUse {
2075 id: "cancellation_aware_1".into(),
2076 name: "cancellation_aware".into(),
2077 raw_input: r#"{}"#.into(),
2078 input: json!({}),
2079 is_input_complete: true,
2080 thought_signature: None,
2081 },
2082 ));
2083 fake_model.end_last_completion_stream();
2084
2085 cx.run_until_parked();
2086
2087 // Wait for the tool call to be reported
2088 let mut tool_started = false;
2089 let deadline = cx.executor().num_cpus() * 100;
2090 for _ in 0..deadline {
2091 cx.run_until_parked();
2092
2093 while let Some(Some(event)) = events.next().now_or_never() {
2094 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2095 if tool_call.title == "Cancellation Aware Tool" {
2096 tool_started = true;
2097 break;
2098 }
2099 }
2100 }
2101
2102 if tool_started {
2103 break;
2104 }
2105
2106 cx.background_executor
2107 .timer(Duration::from_millis(10))
2108 .await;
2109 }
2110 assert!(tool_started, "expected cancellation aware tool to start");
2111
2112 // Cancel the thread and wait for it to complete
2113 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2114
2115 // The cancel task should complete promptly because the tool handles cancellation
2116 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2117 futures::select! {
2118 _ = cancel_task.fuse() => {}
2119 _ = timeout.fuse() => {
2120 panic!("cancel task timed out - tool did not respond to cancellation");
2121 }
2122 }
2123
2124 // Verify the tool detected cancellation via its flag
2125 assert!(
2126 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2127 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2128 );
2129
2130 // Collect remaining events
2131 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2132
2133 // Verify we got a cancellation stop event
2134 assert_eq!(
2135 stop_events(remaining_events),
2136 vec![acp::StopReason::Cancelled],
2137 );
2138
2139 // Verify we can send a new message after cancellation
2140 verify_thread_recovery(&thread, &fake_model, cx).await;
2141}
2142
2143/// Helper to verify thread can recover after cancellation by sending a simple message.
2144async fn verify_thread_recovery(
2145 thread: &Entity<Thread>,
2146 fake_model: &FakeLanguageModel,
2147 cx: &mut TestAppContext,
2148) {
2149 let events = thread
2150 .update(cx, |thread, cx| {
2151 thread.send(
2152 UserMessageId::new(),
2153 ["Testing: reply with 'Hello' then stop."],
2154 cx,
2155 )
2156 })
2157 .unwrap();
2158 cx.run_until_parked();
2159 fake_model.send_last_completion_stream_text_chunk("Hello");
2160 fake_model
2161 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2162 fake_model.end_last_completion_stream();
2163
2164 let events = events.collect::<Vec<_>>().await;
2165 thread.update(cx, |thread, _cx| {
2166 let message = thread.last_received_or_pending_message().unwrap();
2167 let agent_message = message.as_agent_message().unwrap();
2168 assert_eq!(
2169 agent_message.content,
2170 vec![AgentMessageContent::Text("Hello".to_string())]
2171 );
2172 });
2173 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2174}
2175
2176/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2177async fn wait_for_terminal_tool_started(
2178 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2179 cx: &mut TestAppContext,
2180) {
2181 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2182 for _ in 0..deadline {
2183 cx.run_until_parked();
2184
2185 while let Some(Some(event)) = events.next().now_or_never() {
2186 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2187 update,
2188 ))) = &event
2189 {
2190 if update.fields.content.as_ref().is_some_and(|content| {
2191 content
2192 .iter()
2193 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2194 }) {
2195 return;
2196 }
2197 }
2198 }
2199
2200 cx.background_executor
2201 .timer(Duration::from_millis(10))
2202 .await;
2203 }
2204 panic!("terminal tool did not start within the expected time");
2205}
2206
2207/// Collects events until a Stop event is received, driving the executor to completion.
2208async fn collect_events_until_stop(
2209 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2210 cx: &mut TestAppContext,
2211) -> Vec<Result<ThreadEvent>> {
2212 let mut collected = Vec::new();
2213 let deadline = cx.executor().num_cpus() * 200;
2214
2215 for _ in 0..deadline {
2216 cx.executor().advance_clock(Duration::from_millis(10));
2217 cx.run_until_parked();
2218
2219 while let Some(Some(event)) = events.next().now_or_never() {
2220 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2221 collected.push(event);
2222 if is_stop {
2223 return collected;
2224 }
2225 }
2226 }
2227 panic!(
2228 "did not receive Stop event within the expected time; collected {} events",
2229 collected.len()
2230 );
2231}
2232
2233#[gpui::test]
2234async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2235 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2236 always_allow_tools(cx);
2237 let fake_model = model.as_fake();
2238
2239 let environment = Rc::new(cx.update(|cx| {
2240 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2241 }));
2242 let handle = environment.terminal_handle.clone().unwrap();
2243
2244 let message_id = UserMessageId::new();
2245 let mut events = thread
2246 .update(cx, |thread, cx| {
2247 thread.add_tool(crate::TerminalTool::new(
2248 thread.project().clone(),
2249 environment,
2250 ));
2251 thread.send(message_id.clone(), ["run a command"], cx)
2252 })
2253 .unwrap();
2254
2255 cx.run_until_parked();
2256
2257 // Simulate the model calling the terminal tool
2258 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2259 LanguageModelToolUse {
2260 id: "terminal_tool_1".into(),
2261 name: TerminalTool::NAME.into(),
2262 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2263 input: json!({"command": "sleep 1000", "cd": "."}),
2264 is_input_complete: true,
2265 thought_signature: None,
2266 },
2267 ));
2268 fake_model.end_last_completion_stream();
2269
2270 // Wait for the terminal tool to start running
2271 wait_for_terminal_tool_started(&mut events, cx).await;
2272
2273 // Truncate the thread while the terminal is running
2274 thread
2275 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2276 .unwrap();
2277
2278 // Drive the executor to let cancellation complete
2279 let _ = collect_events_until_stop(&mut events, cx).await;
2280
2281 // Verify the terminal was killed
2282 assert!(
2283 handle.was_killed(),
2284 "expected terminal handle to be killed on truncate"
2285 );
2286
2287 // Verify the thread is empty after truncation
2288 thread.update(cx, |thread, _cx| {
2289 assert_eq!(
2290 thread.to_markdown(),
2291 "",
2292 "expected thread to be empty after truncating the only message"
2293 );
2294 });
2295
2296 // Verify we can send a new message after truncation
2297 verify_thread_recovery(&thread, &fake_model, cx).await;
2298}
2299
2300#[gpui::test]
2301async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2302 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2303 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2304 always_allow_tools(cx);
2305 let fake_model = model.as_fake();
2306
2307 let environment = Rc::new(MultiTerminalEnvironment::new());
2308
2309 let mut events = thread
2310 .update(cx, |thread, cx| {
2311 thread.add_tool(crate::TerminalTool::new(
2312 thread.project().clone(),
2313 environment.clone(),
2314 ));
2315 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2316 })
2317 .unwrap();
2318
2319 cx.run_until_parked();
2320
2321 // Simulate the model calling two terminal tools
2322 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2323 LanguageModelToolUse {
2324 id: "terminal_tool_1".into(),
2325 name: TerminalTool::NAME.into(),
2326 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2327 input: json!({"command": "sleep 1000", "cd": "."}),
2328 is_input_complete: true,
2329 thought_signature: None,
2330 },
2331 ));
2332 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2333 LanguageModelToolUse {
2334 id: "terminal_tool_2".into(),
2335 name: TerminalTool::NAME.into(),
2336 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2337 input: json!({"command": "sleep 2000", "cd": "."}),
2338 is_input_complete: true,
2339 thought_signature: None,
2340 },
2341 ));
2342 fake_model.end_last_completion_stream();
2343
2344 // Wait for both terminal tools to start by counting terminal content updates
2345 let mut terminals_started = 0;
2346 let deadline = cx.executor().num_cpus() * 100;
2347 for _ in 0..deadline {
2348 cx.run_until_parked();
2349
2350 while let Some(Some(event)) = events.next().now_or_never() {
2351 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2352 update,
2353 ))) = &event
2354 {
2355 if update.fields.content.as_ref().is_some_and(|content| {
2356 content
2357 .iter()
2358 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2359 }) {
2360 terminals_started += 1;
2361 if terminals_started >= 2 {
2362 break;
2363 }
2364 }
2365 }
2366 }
2367 if terminals_started >= 2 {
2368 break;
2369 }
2370
2371 cx.background_executor
2372 .timer(Duration::from_millis(10))
2373 .await;
2374 }
2375 assert!(
2376 terminals_started >= 2,
2377 "expected 2 terminal tools to start, got {terminals_started}"
2378 );
2379
2380 // Cancel the thread while both terminals are running
2381 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2382
2383 // Collect remaining events
2384 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2385
2386 // Verify both terminal handles were killed
2387 let handles = environment.handles();
2388 assert_eq!(
2389 handles.len(),
2390 2,
2391 "expected 2 terminal handles to be created"
2392 );
2393 assert!(
2394 handles[0].was_killed(),
2395 "expected first terminal handle to be killed on cancellation"
2396 );
2397 assert!(
2398 handles[1].was_killed(),
2399 "expected second terminal handle to be killed on cancellation"
2400 );
2401
2402 // Verify we got a cancellation stop event
2403 assert_eq!(
2404 stop_events(remaining_events),
2405 vec![acp::StopReason::Cancelled],
2406 );
2407}
2408
2409#[gpui::test]
2410async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2411 // Tests that clicking the stop button on the terminal card (as opposed to the main
2412 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2413 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2414 always_allow_tools(cx);
2415 let fake_model = model.as_fake();
2416
2417 let environment = Rc::new(cx.update(|cx| {
2418 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2419 }));
2420 let handle = environment.terminal_handle.clone().unwrap();
2421
2422 let mut events = thread
2423 .update(cx, |thread, cx| {
2424 thread.add_tool(crate::TerminalTool::new(
2425 thread.project().clone(),
2426 environment,
2427 ));
2428 thread.send(UserMessageId::new(), ["run a command"], cx)
2429 })
2430 .unwrap();
2431
2432 cx.run_until_parked();
2433
2434 // Simulate the model calling the terminal tool
2435 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2436 LanguageModelToolUse {
2437 id: "terminal_tool_1".into(),
2438 name: TerminalTool::NAME.into(),
2439 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2440 input: json!({"command": "sleep 1000", "cd": "."}),
2441 is_input_complete: true,
2442 thought_signature: None,
2443 },
2444 ));
2445 fake_model.end_last_completion_stream();
2446
2447 // Wait for the terminal tool to start running
2448 wait_for_terminal_tool_started(&mut events, cx).await;
2449
2450 // Simulate user clicking stop on the terminal card itself.
2451 // This sets the flag and signals exit (simulating what the real UI would do).
2452 handle.set_stopped_by_user(true);
2453 handle.killed.store(true, Ordering::SeqCst);
2454 handle.signal_exit();
2455
2456 // Wait for the tool to complete
2457 cx.run_until_parked();
2458
2459 // The thread continues after tool completion - simulate the model ending its turn
2460 fake_model
2461 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2462 fake_model.end_last_completion_stream();
2463
2464 // Collect remaining events
2465 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2466
2467 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2468 assert_eq!(
2469 stop_events(remaining_events),
2470 vec![acp::StopReason::EndTurn],
2471 );
2472
2473 // Verify the tool result indicates user stopped
2474 thread.update(cx, |thread, _cx| {
2475 let message = thread.last_received_or_pending_message().unwrap();
2476 let agent_message = message.as_agent_message().unwrap();
2477
2478 let tool_use = agent_message
2479 .content
2480 .iter()
2481 .find_map(|content| match content {
2482 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2483 _ => None,
2484 })
2485 .expect("expected tool use in agent message");
2486
2487 let tool_result = agent_message
2488 .tool_results
2489 .get(&tool_use.id)
2490 .expect("expected tool result");
2491
2492 let result_text = match &tool_result.content {
2493 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2494 _ => panic!("expected text content in tool result"),
2495 };
2496
2497 assert!(
2498 result_text.contains("The user stopped this command"),
2499 "expected tool result to indicate user stopped, got: {result_text}"
2500 );
2501 });
2502}
2503
#[gpui::test]
async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
    // The terminal never exits on its own, so only the `timeout_ms` requested by the model
    // can end the command. The turn must still finish with EndTurn (not Cancelled), and the
    // result text must mention the timeout rather than a user stop.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    always_allow_tools(cx);
    let fake_model = model.as_fake();

    // A terminal handle that never exits unless something kills it.
    let environment = Rc::new(cx.update(|cx| {
        FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
    }));
    let handle = environment.terminal_handle.clone().unwrap();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(crate::TerminalTool::new(
                thread.project().clone(),
                environment,
            ));
            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
        })
        .unwrap();

    cx.run_until_parked();

    // Simulate the model calling the terminal tool with a short timeout
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "terminal_tool_1".into(),
            name: TerminalTool::NAME.into(),
            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Wait for the terminal tool to start running
    wait_for_terminal_tool_started(&mut events, cx).await;

    // Advance clock past the timeout (200ms of fake time > the 100ms `timeout_ms` above)
    cx.executor().advance_clock(Duration::from_millis(200));
    cx.run_until_parked();

    // The thread continues after tool completion - simulate the model ending its turn
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    // Collect remaining events
    let remaining_events = collect_events_until_stop(&mut events, cx).await;

    // Verify the terminal was killed due to timeout
    assert!(
        handle.was_killed(),
        "expected terminal handle to be killed on timeout"
    );

    // Verify we got an EndTurn (the tool completed, just with timeout)
    assert_eq!(
        stop_events(remaining_events),
        vec![acp::StopReason::EndTurn],
    );

    // Verify the tool result indicates timeout, not user stopped
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();

        // Find the (single) tool use in the reply so its result can be looked up by id.
        let tool_use = agent_message
            .content
            .iter()
            .find_map(|content| match content {
                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
                _ => None,
            })
            .expect("expected tool use in agent message");

        let tool_result = agent_message
            .tool_results
            .get(&tool_use.id)
            .expect("expected tool result");

        let result_text = match &tool_result.content {
            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
            _ => panic!("expected text content in tool result"),
        };

        assert!(
            result_text.contains("timed out"),
            "expected tool result to indicate timeout, got: {result_text}"
        );
        // A timeout kill must not be reported as a user-initiated stop.
        assert!(
            !result_text.contains("The user stopped"),
            "tool result should not mention user stopped when it timed out, got: {result_text}"
        );
    });
}
2602
2603#[gpui::test]
2604async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2605 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2606 let fake_model = model.as_fake();
2607
2608 let events_1 = thread
2609 .update(cx, |thread, cx| {
2610 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2611 })
2612 .unwrap();
2613 cx.run_until_parked();
2614 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2615 cx.run_until_parked();
2616
2617 let events_2 = thread
2618 .update(cx, |thread, cx| {
2619 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2620 })
2621 .unwrap();
2622 cx.run_until_parked();
2623 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2624 fake_model
2625 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2626 fake_model.end_last_completion_stream();
2627
2628 let events_1 = events_1.collect::<Vec<_>>().await;
2629 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2630 let events_2 = events_2.collect::<Vec<_>>().await;
2631 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2632}
2633
2634#[gpui::test]
2635async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) {
2636 // Regression test: when a completion fails with a retryable error (e.g. upstream 500),
2637 // the retry loop waits on a timer. If the user switches models and sends a new message
2638 // during that delay, the old turn should exit immediately instead of retrying with the
2639 // stale model.
2640 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2641 let model_a = model.as_fake();
2642
2643 // Start a turn with model_a.
2644 let events_1 = thread
2645 .update(cx, |thread, cx| {
2646 thread.send(UserMessageId::new(), ["Hello"], cx)
2647 })
2648 .unwrap();
2649 cx.run_until_parked();
2650 assert_eq!(model_a.completion_count(), 1);
2651
2652 // Model returns a retryable upstream 500. The turn enters the retry delay.
2653 model_a.send_last_completion_stream_error(
2654 LanguageModelCompletionError::UpstreamProviderError {
2655 message: "Internal server error".to_string(),
2656 status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
2657 retry_after: None,
2658 },
2659 );
2660 model_a.end_last_completion_stream();
2661 cx.run_until_parked();
2662
2663 // The old completion was consumed; model_a has no pending requests yet because the
2664 // retry timer hasn't fired.
2665 assert_eq!(model_a.completion_count(), 0);
2666
2667 // Switch to model_b and send a new message. This cancels the old turn.
2668 let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
2669 "fake", "model-b", "Model B", false,
2670 ));
2671 thread.update(cx, |thread, cx| {
2672 thread.set_model(model_b.clone(), cx);
2673 });
2674 let events_2 = thread
2675 .update(cx, |thread, cx| {
2676 thread.send(UserMessageId::new(), ["Continue"], cx)
2677 })
2678 .unwrap();
2679 cx.run_until_parked();
2680
2681 // model_b should have received its completion request.
2682 assert_eq!(model_b.as_fake().completion_count(), 1);
2683
2684 // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s).
2685 cx.executor().advance_clock(Duration::from_secs(10));
2686 cx.run_until_parked();
2687
2688 // model_a must NOT have received another completion request — the cancelled turn
2689 // should have exited during the retry delay rather than retrying with the old model.
2690 assert_eq!(
2691 model_a.completion_count(),
2692 0,
2693 "old model should not receive a retry request after cancellation"
2694 );
2695
2696 // Complete model_b's turn.
2697 model_b
2698 .as_fake()
2699 .send_last_completion_stream_text_chunk("Done!");
2700 model_b
2701 .as_fake()
2702 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2703 model_b.as_fake().end_last_completion_stream();
2704
2705 let events_1 = events_1.collect::<Vec<_>>().await;
2706 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2707
2708 let events_2 = events_2.collect::<Vec<_>>().await;
2709 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2710}
2711
2712#[gpui::test]
2713async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2714 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2715 let fake_model = model.as_fake();
2716
2717 let events_1 = thread
2718 .update(cx, |thread, cx| {
2719 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2720 })
2721 .unwrap();
2722 cx.run_until_parked();
2723 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2724 fake_model
2725 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2726 fake_model.end_last_completion_stream();
2727 let events_1 = events_1.collect::<Vec<_>>().await;
2728
2729 let events_2 = thread
2730 .update(cx, |thread, cx| {
2731 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2732 })
2733 .unwrap();
2734 cx.run_until_parked();
2735 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2736 fake_model
2737 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2738 fake_model.end_last_completion_stream();
2739 let events_2 = events_2.collect::<Vec<_>>().await;
2740
2741 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2742 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2743}
2744
2745#[gpui::test]
2746async fn test_refusal(cx: &mut TestAppContext) {
2747 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2748 let fake_model = model.as_fake();
2749
2750 let events = thread
2751 .update(cx, |thread, cx| {
2752 thread.send(UserMessageId::new(), ["Hello"], cx)
2753 })
2754 .unwrap();
2755 cx.run_until_parked();
2756 thread.read_with(cx, |thread, _| {
2757 assert_eq!(
2758 thread.to_markdown(),
2759 indoc! {"
2760 ## User
2761
2762 Hello
2763 "}
2764 );
2765 });
2766
2767 fake_model.send_last_completion_stream_text_chunk("Hey!");
2768 cx.run_until_parked();
2769 thread.read_with(cx, |thread, _| {
2770 assert_eq!(
2771 thread.to_markdown(),
2772 indoc! {"
2773 ## User
2774
2775 Hello
2776
2777 ## Assistant
2778
2779 Hey!
2780 "}
2781 );
2782 });
2783
2784 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2785 fake_model
2786 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2787 let events = events.collect::<Vec<_>>().await;
2788 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2789 thread.read_with(cx, |thread, _| {
2790 assert_eq!(thread.to_markdown(), "");
2791 });
2792}
2793
/// Truncating at the first user message empties the thread and clears the reported token
/// usage. A subsequent turn then starts fresh and reports only its own usage.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model responds, only the user message is present and no usage is known.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // The reply and its usage update are reflected; `used_tokens` is input + output.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 32_000,
                output_tokens: 16_000,
            })
        );
    });

    // Truncating at the first message removes it and everything after it.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The new user message is visible immediately, before the executor runs.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    // Only the new turn's content and usage are reported.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
                max_output_tokens: None,
                input_tokens: 40_000,
                output_tokens: 20_000,
            })
        );
    });
}
2915
2916#[gpui::test]
2917async fn test_truncate_second_message(cx: &mut TestAppContext) {
2918 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2919 let fake_model = model.as_fake();
2920
2921 thread
2922 .update(cx, |thread, cx| {
2923 thread.send(UserMessageId::new(), ["Message 1"], cx)
2924 })
2925 .unwrap();
2926 cx.run_until_parked();
2927 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2928 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2929 language_model::TokenUsage {
2930 input_tokens: 32_000,
2931 output_tokens: 16_000,
2932 cache_creation_input_tokens: 0,
2933 cache_read_input_tokens: 0,
2934 },
2935 ));
2936 fake_model.end_last_completion_stream();
2937 cx.run_until_parked();
2938
2939 let assert_first_message_state = |cx: &mut TestAppContext| {
2940 thread.clone().read_with(cx, |thread, _| {
2941 assert_eq!(
2942 thread.to_markdown(),
2943 indoc! {"
2944 ## User
2945
2946 Message 1
2947
2948 ## Assistant
2949
2950 Message 1 response
2951 "}
2952 );
2953
2954 assert_eq!(
2955 thread.latest_token_usage(),
2956 Some(acp_thread::TokenUsage {
2957 used_tokens: 32_000 + 16_000,
2958 max_tokens: 1_000_000,
2959 max_output_tokens: None,
2960 input_tokens: 32_000,
2961 output_tokens: 16_000,
2962 })
2963 );
2964 });
2965 };
2966
2967 assert_first_message_state(cx);
2968
2969 let second_message_id = UserMessageId::new();
2970 thread
2971 .update(cx, |thread, cx| {
2972 thread.send(second_message_id.clone(), ["Message 2"], cx)
2973 })
2974 .unwrap();
2975 cx.run_until_parked();
2976
2977 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2978 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2979 language_model::TokenUsage {
2980 input_tokens: 40_000,
2981 output_tokens: 20_000,
2982 cache_creation_input_tokens: 0,
2983 cache_read_input_tokens: 0,
2984 },
2985 ));
2986 fake_model.end_last_completion_stream();
2987 cx.run_until_parked();
2988
2989 thread.read_with(cx, |thread, _| {
2990 assert_eq!(
2991 thread.to_markdown(),
2992 indoc! {"
2993 ## User
2994
2995 Message 1
2996
2997 ## Assistant
2998
2999 Message 1 response
3000
3001 ## User
3002
3003 Message 2
3004
3005 ## Assistant
3006
3007 Message 2 response
3008 "}
3009 );
3010
3011 assert_eq!(
3012 thread.latest_token_usage(),
3013 Some(acp_thread::TokenUsage {
3014 used_tokens: 40_000 + 20_000,
3015 max_tokens: 1_000_000,
3016 max_output_tokens: None,
3017 input_tokens: 40_000,
3018 output_tokens: 20_000,
3019 })
3020 );
3021 });
3022
3023 thread
3024 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
3025 .unwrap();
3026 cx.run_until_parked();
3027
3028 assert_first_message_state(cx);
3029}
3030
3031#[gpui::test]
3032async fn test_title_generation(cx: &mut TestAppContext) {
3033 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3034 let fake_model = model.as_fake();
3035
3036 let summary_model = Arc::new(FakeLanguageModel::default());
3037 thread.update(cx, |thread, cx| {
3038 thread.set_summarization_model(Some(summary_model.clone()), cx)
3039 });
3040
3041 let send = thread
3042 .update(cx, |thread, cx| {
3043 thread.send(UserMessageId::new(), ["Hello"], cx)
3044 })
3045 .unwrap();
3046 cx.run_until_parked();
3047
3048 fake_model.send_last_completion_stream_text_chunk("Hey!");
3049 fake_model.end_last_completion_stream();
3050 cx.run_until_parked();
3051 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
3052
3053 // Ensure the summary model has been invoked to generate a title.
3054 summary_model.send_last_completion_stream_text_chunk("Hello ");
3055 summary_model.send_last_completion_stream_text_chunk("world\nG");
3056 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
3057 summary_model.end_last_completion_stream();
3058 send.collect::<Vec<_>>().await;
3059 cx.run_until_parked();
3060 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
3061
3062 // Send another message, ensuring no title is generated this time.
3063 let send = thread
3064 .update(cx, |thread, cx| {
3065 thread.send(UserMessageId::new(), ["Hello again"], cx)
3066 })
3067 .unwrap();
3068 cx.run_until_parked();
3069 fake_model.send_last_completion_stream_text_chunk("Hey again!");
3070 fake_model.end_last_completion_stream();
3071 cx.run_until_parked();
3072 assert_eq!(summary_model.pending_completions(), Vec::new());
3073 send.collect::<Vec<_>>().await;
3074 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
3075}
3076
/// Verifies that a tool use still waiting on the user's permission is left out
/// of a freshly built completion request, while the echo tool use that already
/// ran is included together with its result.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Bound to a named (underscore-prefixed) variable so the event stream lives
    // until the end of the test instead of being dropped immediately.
    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Nothing in this test grants permission, so this tool use stays pending.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::NAME.into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool use runs to completion; its result is expected in the request.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // The first request message is not checked (presumably the system prompt —
    // TODO confirm).
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
3156
3157#[gpui::test]
3158async fn test_agent_connection(cx: &mut TestAppContext) {
3159 cx.update(settings::init);
3160 let templates = Templates::new();
3161
3162 // Initialize language model system with test provider
3163 cx.update(|cx| {
3164 gpui_tokio::init(cx);
3165
3166 let http_client = FakeHttpClient::with_404_response();
3167 let clock = Arc::new(clock::FakeSystemClock::new());
3168 let client = Client::new(clock, http_client, cx);
3169 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3170 language_model::init(client.clone(), cx);
3171 language_models::init(user_store, client.clone(), cx);
3172 LanguageModelRegistry::test(cx);
3173 });
3174 cx.executor().forbid_parking();
3175
3176 // Create a project for new_thread
3177 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3178 fake_fs.insert_tree(path!("/test"), json!({})).await;
3179 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3180 let cwd = Path::new("/test");
3181 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3182
3183 // Create agent and connection
3184 let agent = NativeAgent::new(
3185 project.clone(),
3186 thread_store,
3187 templates.clone(),
3188 None,
3189 fake_fs.clone(),
3190 &mut cx.to_async(),
3191 )
3192 .await
3193 .unwrap();
3194 let connection = NativeAgentConnection(agent.clone());
3195
3196 // Create a thread using new_thread
3197 let connection_rc = Rc::new(connection.clone());
3198 let acp_thread = cx
3199 .update(|cx| connection_rc.new_session(project, cwd, cx))
3200 .await
3201 .expect("new_thread should succeed");
3202
3203 // Get the session_id from the AcpThread
3204 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3205
3206 // Test model_selector returns Some
3207 let selector_opt = connection.model_selector(&session_id);
3208 assert!(
3209 selector_opt.is_some(),
3210 "agent should always support ModelSelector"
3211 );
3212 let selector = selector_opt.unwrap();
3213
3214 // Test list_models
3215 let listed_models = cx
3216 .update(|cx| selector.list_models(cx))
3217 .await
3218 .expect("list_models should succeed");
3219 let AgentModelList::Grouped(listed_models) = listed_models else {
3220 panic!("Unexpected model list type");
3221 };
3222 assert!(!listed_models.is_empty(), "should have at least one model");
3223 assert_eq!(
3224 listed_models[&AgentModelGroupName("Fake".into())][0]
3225 .id
3226 .0
3227 .as_ref(),
3228 "fake/fake"
3229 );
3230
3231 // Test selected_model returns the default
3232 let model = cx
3233 .update(|cx| selector.selected_model(cx))
3234 .await
3235 .expect("selected_model should succeed");
3236 let model = cx
3237 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3238 .unwrap();
3239 let model = model.as_fake();
3240 assert_eq!(model.id().0, "fake", "should return default model");
3241
3242 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3243 cx.run_until_parked();
3244 model.send_last_completion_stream_text_chunk("def");
3245 cx.run_until_parked();
3246 acp_thread.read_with(cx, |thread, cx| {
3247 assert_eq!(
3248 thread.to_markdown(cx),
3249 indoc! {"
3250 ## User
3251
3252 abc
3253
3254 ## Assistant
3255
3256 def
3257
3258 "}
3259 )
3260 });
3261
3262 // Test cancel
3263 cx.update(|cx| connection.cancel(&session_id, cx));
3264 request.await.expect("prompt should fail gracefully");
3265
3266 // Explicitly close the session and drop the ACP thread.
3267 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3268 .await
3269 .unwrap();
3270 drop(acp_thread);
3271 let result = cx
3272 .update(|cx| {
3273 connection.prompt(
3274 Some(acp_thread::UserMessageId::new()),
3275 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3276 cx,
3277 )
3278 })
3279 .await;
3280 assert_eq!(
3281 result.as_ref().unwrap_err().to_string(),
3282 "Session not found",
3283 "unexpected result: {:?}",
3284 result
3285 );
3286}
3287
3288#[gpui::test]
3289async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3290 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3291 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
3292 let fake_model = model.as_fake();
3293
3294 let mut events = thread
3295 .update(cx, |thread, cx| {
3296 thread.send(UserMessageId::new(), ["Echo something"], cx)
3297 })
3298 .unwrap();
3299 cx.run_until_parked();
3300
3301 // Simulate streaming partial input.
3302 let input = json!({});
3303 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3304 LanguageModelToolUse {
3305 id: "1".into(),
3306 name: EchoTool::NAME.into(),
3307 raw_input: input.to_string(),
3308 input,
3309 is_input_complete: false,
3310 thought_signature: None,
3311 },
3312 ));
3313
3314 // Input streaming completed
3315 let input = json!({ "text": "Hello!" });
3316 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3317 LanguageModelToolUse {
3318 id: "1".into(),
3319 name: "echo".into(),
3320 raw_input: input.to_string(),
3321 input,
3322 is_input_complete: true,
3323 thought_signature: None,
3324 },
3325 ));
3326 fake_model.end_last_completion_stream();
3327 cx.run_until_parked();
3328
3329 let tool_call = expect_tool_call(&mut events).await;
3330 assert_eq!(
3331 tool_call,
3332 acp::ToolCall::new("1", "Echo")
3333 .raw_input(json!({}))
3334 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3335 );
3336 let update = expect_tool_call_update_fields(&mut events).await;
3337 assert_eq!(
3338 update,
3339 acp::ToolCallUpdate::new(
3340 "1",
3341 acp::ToolCallUpdateFields::new()
3342 .title("Echo")
3343 .kind(acp::ToolKind::Other)
3344 .raw_input(json!({ "text": "Hello!"}))
3345 )
3346 );
3347 let update = expect_tool_call_update_fields(&mut events).await;
3348 assert_eq!(
3349 update,
3350 acp::ToolCallUpdate::new(
3351 "1",
3352 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3353 )
3354 );
3355 let update = expect_tool_call_update_fields(&mut events).await;
3356 assert_eq!(
3357 update,
3358 acp::ToolCallUpdate::new(
3359 "1",
3360 acp::ToolCallUpdateFields::new()
3361 .status(acp::ToolCallStatus::Completed)
3362 .raw_output("Hello!")
3363 )
3364 );
3365}
3366
/// Verifies that a completion that streams successfully produces no
/// `ThreadEvent::Retry` events and leaves the expected transcript.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, recording any retries. The loop also
    // ends if the stream yields an `Err` or closes.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
3409
/// Verifies the retry path for a `ServerOverloaded` error that carries a
/// `retry_after` delay.
///
/// Exactly one retry (attempt 1) should happen once the delay has elapsed. The
/// transcript should keep both the partial response and the retried response,
/// separated by a `[resume]` marker.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with a retryable error.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock by the `retry_after` delay (presumably what the
    // thread waits before retrying) so the retried request is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Drain events until the turn stops, recording any retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
3473
/// Verifies the retry path when a completion fails right after emitting a tool
/// use.
///
/// The retried request must contain that tool use together with its result.
/// The thread's last message must then be the retried response, with no
/// leftover tool results.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    // Emit the tool use, then fail the stream with a retryable error.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Let the retry delay elapse, then inspect the retried request.
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_received_or_pending_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
3553
3554#[gpui::test]
3555async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3556 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3557 let fake_model = model.as_fake();
3558
3559 let mut events = thread
3560 .update(cx, |thread, cx| {
3561 thread.send(UserMessageId::new(), ["Hello!"], cx)
3562 })
3563 .unwrap();
3564 cx.run_until_parked();
3565
3566 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3567 fake_model.send_last_completion_stream_error(
3568 LanguageModelCompletionError::ServerOverloaded {
3569 provider: LanguageModelProviderName::new("Anthropic"),
3570 retry_after: Some(Duration::from_secs(3)),
3571 },
3572 );
3573 fake_model.end_last_completion_stream();
3574 cx.executor().advance_clock(Duration::from_secs(3));
3575 cx.run_until_parked();
3576 }
3577
3578 let mut errors = Vec::new();
3579 let mut retry_events = Vec::new();
3580 while let Some(event) = events.next().await {
3581 match event {
3582 Ok(ThreadEvent::Retry(retry_status)) => {
3583 retry_events.push(retry_status);
3584 }
3585 Ok(ThreadEvent::Stop(..)) => break,
3586 Err(error) => errors.push(error),
3587 _ => {}
3588 }
3589 }
3590
3591 assert_eq!(
3592 retry_events.len(),
3593 crate::thread::MAX_RETRY_ATTEMPTS as usize
3594 );
3595 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3596 assert_eq!(retry_events[i].attempt, i + 1);
3597 }
3598 assert_eq!(errors.len(), 1);
3599 let error = errors[0]
3600 .downcast_ref::<LanguageModelCompletionError>()
3601 .unwrap();
3602 assert!(matches!(
3603 error,
3604 LanguageModelCompletionError::ServerOverloaded { .. }
3605 ));
3606}
3607
3608/// Filters out the stop events for asserting against in tests
3609fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3610 result_events
3611 .into_iter()
3612 .filter_map(|event| match event.unwrap() {
3613 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3614 _ => None,
3615 })
3616 .collect()
3617}
3618
/// Handles produced by `setup` for a freshly created thread and its environment.
struct ThreadTest {
    /// The language model the thread was created with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context passed to the thread on creation.
    project_context: Entity<ProjectContext>,
    /// The project's context server store (e.g. for `setup_context_server`).
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing the project and the user settings file.
    fs: Arc<FakeFs>,
}
3626
3627enum TestModel {
3628 Sonnet4,
3629 Fake,
3630}
3631
3632impl TestModel {
3633 fn id(&self) -> LanguageModelId {
3634 match self {
3635 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3636 TestModel::Fake => unreachable!(),
3637 }
3638 }
3639}
3640
/// Builds a `Thread` for tests on top of a fresh `FakeFs`-backed project.
///
/// Steps:
/// - writes a user settings file whose default profile (`test-profile`)
///   enables the test tools;
/// - initializes settings and starts watching the settings file;
/// - resolves `model`. `TestModel::Fake` becomes a `FakeLanguageModel`.
///   `TestModel::Sonnet4` uses the production client (real HTTP via
///   `ReqwestClient`) and authenticates the model's provider first.
///
/// Parking is allowed on the executor from this point on.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Enable every test tool in the default profile so tests can call them.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                            CancellationAwareTool::NAME: true,
                            (TerminalTool::NAME): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // The real model needs a real HTTP client and registered
                // language model providers.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Apply the settings file written above (and any later edits to it).
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                // Wait for authentication before handing the model to the thread.
                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
3743
/// Installs `env_logger` when the test binary loads (via `ctor`), but only if
/// `RUST_LOG` is set, so tests stay quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
3751
3752fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3753 let fs = fs.clone();
3754 cx.spawn({
3755 async move |cx| {
3756 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3757 cx.background_executor(),
3758 fs,
3759 paths::settings_file().clone(),
3760 );
3761 let _watcher_task = watcher_task;
3762
3763 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3764 cx.update(|cx| {
3765 SettingsStore::update_global(cx, |settings, cx| {
3766 settings.set_user_settings(&new_settings_content, cx)
3767 })
3768 })
3769 .ok();
3770 }
3771 }
3772 })
3773 .detach();
3774}
3775
3776fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3777 completion
3778 .tools
3779 .iter()
3780 .map(|tool| tool.name.clone())
3781 .collect()
3782}
3783
/// Registers and starts a fake stdio context server named `name`.
///
/// The server is added to the global `ProjectSettings` (enabled, command
/// `somebinary`) and started on `context_server_store` with an in-memory fake
/// transport. The fake server:
/// - answers `Initialize`, advertising tool `list_changed` support;
/// - answers `ListTools` with `tools`;
/// - forwards every `CallTool` request to the returned receiver.
///
/// Each forwarded `CallTool` arrives as a `(params, responder)` pair. The test
/// must reply through the oneshot `responder`; the handler panics (via
/// `unwrap`) if the responder is dropped without a reply.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            // Hand the call to the test and wait for its reply.
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3863
/// Verifies `Thread::tokens_before_message` across three turns.
///
/// The first user message never has tokens before it. Each later message
/// reports the input token count of the request that preceded it, and earlier
/// values are unaffected by later turns.
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
3970
/// Verifies the effect of truncating a thread on `Thread::tokens_before_message`.
///
/// After truncating at a message, that message can no longer be found (so the
/// lookup returns `None`). Earlier messages are unaffected.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up two turns, each with a response and reported usage.
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}
4041
4042#[gpui::test]
4043async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
4044 init_test(cx);
4045
4046 let fs = FakeFs::new(cx.executor());
4047 fs.insert_tree("/root", json!({})).await;
4048 let project = Project::test(fs, ["/root".as_ref()], cx).await;
4049
4050 // Test 1: Deny rule blocks command
4051 {
4052 let environment = Rc::new(cx.update(|cx| {
4053 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4054 }));
4055
4056 cx.update(|cx| {
4057 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4058 settings.tool_permissions.tools.insert(
4059 TerminalTool::NAME.into(),
4060 agent_settings::ToolRules {
4061 default: Some(settings::ToolPermissionMode::Confirm),
4062 always_allow: vec![],
4063 always_deny: vec![
4064 agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
4065 ],
4066 always_confirm: vec![],
4067 invalid_patterns: vec![],
4068 },
4069 );
4070 agent_settings::AgentSettings::override_global(settings, cx);
4071 });
4072
4073 #[allow(clippy::arc_with_non_send_sync)]
4074 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4075 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4076
4077 let task = cx.update(|cx| {
4078 tool.run(
4079 ToolInput::resolved(crate::TerminalToolInput {
4080 command: "rm -rf /".to_string(),
4081 cd: ".".to_string(),
4082 timeout_ms: None,
4083 }),
4084 event_stream,
4085 cx,
4086 )
4087 });
4088
4089 let result = task.await;
4090 assert!(
4091 result.is_err(),
4092 "expected command to be blocked by deny rule"
4093 );
4094 let err_msg = result.unwrap_err().to_lowercase();
4095 assert!(
4096 err_msg.contains("blocked"),
4097 "error should mention the command was blocked"
4098 );
4099 }
4100
4101 // Test 2: Allow rule skips confirmation (and overrides default: Deny)
4102 {
4103 let environment = Rc::new(cx.update(|cx| {
4104 FakeThreadEnvironment::default()
4105 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4106 }));
4107
4108 cx.update(|cx| {
4109 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4110 settings.tool_permissions.tools.insert(
4111 TerminalTool::NAME.into(),
4112 agent_settings::ToolRules {
4113 default: Some(settings::ToolPermissionMode::Deny),
4114 always_allow: vec![
4115 agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
4116 ],
4117 always_deny: vec![],
4118 always_confirm: vec![],
4119 invalid_patterns: vec![],
4120 },
4121 );
4122 agent_settings::AgentSettings::override_global(settings, cx);
4123 });
4124
4125 #[allow(clippy::arc_with_non_send_sync)]
4126 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4127 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4128
4129 let task = cx.update(|cx| {
4130 tool.run(
4131 ToolInput::resolved(crate::TerminalToolInput {
4132 command: "echo hello".to_string(),
4133 cd: ".".to_string(),
4134 timeout_ms: None,
4135 }),
4136 event_stream,
4137 cx,
4138 )
4139 });
4140
4141 let update = rx.expect_update_fields().await;
4142 assert!(
4143 update.content.iter().any(|blocks| {
4144 blocks
4145 .iter()
4146 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
4147 }),
4148 "expected terminal content (allow rule should skip confirmation and override default deny)"
4149 );
4150
4151 let result = task.await;
4152 assert!(
4153 result.is_ok(),
4154 "expected command to succeed without confirmation"
4155 );
4156 }
4157
4158 // Test 3: global default: allow does NOT override always_confirm patterns
4159 {
4160 let environment = Rc::new(cx.update(|cx| {
4161 FakeThreadEnvironment::default()
4162 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4163 }));
4164
4165 cx.update(|cx| {
4166 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4167 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4168 settings.tool_permissions.tools.insert(
4169 TerminalTool::NAME.into(),
4170 agent_settings::ToolRules {
4171 default: Some(settings::ToolPermissionMode::Allow),
4172 always_allow: vec![],
4173 always_deny: vec![],
4174 always_confirm: vec![
4175 agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
4176 ],
4177 invalid_patterns: vec![],
4178 },
4179 );
4180 agent_settings::AgentSettings::override_global(settings, cx);
4181 });
4182
4183 #[allow(clippy::arc_with_non_send_sync)]
4184 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4185 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
4186
4187 let _task = cx.update(|cx| {
4188 tool.run(
4189 ToolInput::resolved(crate::TerminalToolInput {
4190 command: "sudo rm file".to_string(),
4191 cd: ".".to_string(),
4192 timeout_ms: None,
4193 }),
4194 event_stream,
4195 cx,
4196 )
4197 });
4198
4199 // With global default: allow, confirm patterns are still respected
4200 // The expect_authorization() call will panic if no authorization is requested,
4201 // which validates that the confirm pattern still triggers confirmation
4202 let _auth = rx.expect_authorization().await;
4203
4204 drop(_task);
4205 }
4206
4207 // Test 4: tool-specific default: deny is respected even with global default: allow
4208 {
4209 let environment = Rc::new(cx.update(|cx| {
4210 FakeThreadEnvironment::default()
4211 .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
4212 }));
4213
4214 cx.update(|cx| {
4215 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4216 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
4217 settings.tool_permissions.tools.insert(
4218 TerminalTool::NAME.into(),
4219 agent_settings::ToolRules {
4220 default: Some(settings::ToolPermissionMode::Deny),
4221 always_allow: vec![],
4222 always_deny: vec![],
4223 always_confirm: vec![],
4224 invalid_patterns: vec![],
4225 },
4226 );
4227 agent_settings::AgentSettings::override_global(settings, cx);
4228 });
4229
4230 #[allow(clippy::arc_with_non_send_sync)]
4231 let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
4232 let (event_stream, _rx) = crate::ToolCallEventStream::test();
4233
4234 let task = cx.update(|cx| {
4235 tool.run(
4236 ToolInput::resolved(crate::TerminalToolInput {
4237 command: "echo hello".to_string(),
4238 cd: ".".to_string(),
4239 timeout_ms: None,
4240 }),
4241 event_stream,
4242 cx,
4243 )
4244 });
4245
4246 // tool-specific default: deny is respected even with global default: allow
4247 let result = task.await;
4248 assert!(
4249 result.is_err(),
4250 "expected command to be blocked by tool-specific deny default"
4251 );
4252 let err_msg = result.unwrap_err().to_lowercase();
4253 assert!(
4254 err_msg.contains("disabled"),
4255 "error should mention the tool is disabled, got: {err_msg}"
4256 );
4257 }
4258}
4259
4260#[gpui::test]
4261async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4262 init_test(cx);
4263 cx.update(|cx| {
4264 LanguageModelRegistry::test(cx);
4265 });
4266 cx.update(|cx| {
4267 cx.update_flags(true, vec!["subagents".to_string()]);
4268 });
4269
4270 let fs = FakeFs::new(cx.executor());
4271 fs.insert_tree(
4272 "/",
4273 json!({
4274 "a": {
4275 "b.md": "Lorem"
4276 }
4277 }),
4278 )
4279 .await;
4280 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4281 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4282 let agent = NativeAgent::new(
4283 project.clone(),
4284 thread_store.clone(),
4285 Templates::new(),
4286 None,
4287 fs.clone(),
4288 &mut cx.to_async(),
4289 )
4290 .await
4291 .unwrap();
4292 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4293
4294 let acp_thread = cx
4295 .update(|cx| {
4296 connection
4297 .clone()
4298 .new_session(project.clone(), Path::new(""), cx)
4299 })
4300 .await
4301 .unwrap();
4302 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4303 let thread = agent.read_with(cx, |agent, _| {
4304 agent.sessions.get(&session_id).unwrap().thread.clone()
4305 });
4306 let model = Arc::new(FakeLanguageModel::default());
4307
4308 // Ensure empty threads are not saved, even if they get mutated.
4309 thread.update(cx, |thread, cx| {
4310 thread.set_model(model.clone(), cx);
4311 });
4312 cx.run_until_parked();
4313
4314 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4315 cx.run_until_parked();
4316 model.send_last_completion_stream_text_chunk("spawning subagent");
4317 let subagent_tool_input = SpawnAgentToolInput {
4318 label: "label".to_string(),
4319 message: "subagent task prompt".to_string(),
4320 session_id: None,
4321 };
4322 let subagent_tool_use = LanguageModelToolUse {
4323 id: "subagent_1".into(),
4324 name: SpawnAgentTool::NAME.into(),
4325 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4326 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4327 is_input_complete: true,
4328 thought_signature: None,
4329 };
4330 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4331 subagent_tool_use,
4332 ));
4333 model.end_last_completion_stream();
4334
4335 cx.run_until_parked();
4336
4337 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4338 thread
4339 .running_subagent_ids(cx)
4340 .get(0)
4341 .expect("subagent thread should be running")
4342 .clone()
4343 });
4344
4345 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4346 agent
4347 .sessions
4348 .get(&subagent_session_id)
4349 .expect("subagent session should exist")
4350 .acp_thread
4351 .clone()
4352 });
4353
4354 model.send_last_completion_stream_text_chunk("subagent task response");
4355 model.end_last_completion_stream();
4356
4357 cx.run_until_parked();
4358
4359 assert_eq!(
4360 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4361 indoc! {"
4362 ## User
4363
4364 subagent task prompt
4365
4366 ## Assistant
4367
4368 subagent task response
4369
4370 "}
4371 );
4372
4373 model.send_last_completion_stream_text_chunk("Response");
4374 model.end_last_completion_stream();
4375
4376 send.await.unwrap();
4377
4378 assert_eq!(
4379 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4380 indoc! {r#"
4381 ## User
4382
4383 Prompt
4384
4385 ## Assistant
4386
4387 spawning subagent
4388
4389 **Tool Call: label**
4390 Status: Completed
4391
4392 subagent task response
4393
4394 ## Assistant
4395
4396 Response
4397
4398 "#},
4399 );
4400}
4401
4402#[gpui::test]
4403async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppContext) {
4404 init_test(cx);
4405 cx.update(|cx| {
4406 LanguageModelRegistry::test(cx);
4407 });
4408 cx.update(|cx| {
4409 cx.update_flags(true, vec!["subagents".to_string()]);
4410 });
4411
4412 let fs = FakeFs::new(cx.executor());
4413 fs.insert_tree(
4414 "/",
4415 json!({
4416 "a": {
4417 "b.md": "Lorem"
4418 }
4419 }),
4420 )
4421 .await;
4422 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4423 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4424 let agent = NativeAgent::new(
4425 project.clone(),
4426 thread_store.clone(),
4427 Templates::new(),
4428 None,
4429 fs.clone(),
4430 &mut cx.to_async(),
4431 )
4432 .await
4433 .unwrap();
4434 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4435
4436 let acp_thread = cx
4437 .update(|cx| {
4438 connection
4439 .clone()
4440 .new_session(project.clone(), Path::new(""), cx)
4441 })
4442 .await
4443 .unwrap();
4444 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4445 let thread = agent.read_with(cx, |agent, _| {
4446 agent.sessions.get(&session_id).unwrap().thread.clone()
4447 });
4448 let model = Arc::new(FakeLanguageModel::default());
4449
4450 // Ensure empty threads are not saved, even if they get mutated.
4451 thread.update(cx, |thread, cx| {
4452 thread.set_model(model.clone(), cx);
4453 });
4454 cx.run_until_parked();
4455
4456 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4457 cx.run_until_parked();
4458 model.send_last_completion_stream_text_chunk("spawning subagent");
4459 let subagent_tool_input = SpawnAgentToolInput {
4460 label: "label".to_string(),
4461 message: "subagent task prompt".to_string(),
4462 session_id: None,
4463 };
4464 let subagent_tool_use = LanguageModelToolUse {
4465 id: "subagent_1".into(),
4466 name: SpawnAgentTool::NAME.into(),
4467 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4468 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4469 is_input_complete: true,
4470 thought_signature: None,
4471 };
4472 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4473 subagent_tool_use,
4474 ));
4475 model.end_last_completion_stream();
4476
4477 cx.run_until_parked();
4478
4479 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4480 thread
4481 .running_subagent_ids(cx)
4482 .get(0)
4483 .expect("subagent thread should be running")
4484 .clone()
4485 });
4486
4487 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4488 agent
4489 .sessions
4490 .get(&subagent_session_id)
4491 .expect("subagent session should exist")
4492 .acp_thread
4493 .clone()
4494 });
4495
4496 model.send_last_completion_stream_text_chunk("subagent task response 1");
4497 model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
4498 text: "thinking more about the subagent task".into(),
4499 signature: None,
4500 });
4501 model.send_last_completion_stream_text_chunk("subagent task response 2");
4502 model.end_last_completion_stream();
4503
4504 cx.run_until_parked();
4505
4506 assert_eq!(
4507 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4508 indoc! {"
4509 ## User
4510
4511 subagent task prompt
4512
4513 ## Assistant
4514
4515 subagent task response 1
4516
4517 <thinking>
4518 thinking more about the subagent task
4519 </thinking>
4520
4521 subagent task response 2
4522
4523 "}
4524 );
4525
4526 model.send_last_completion_stream_text_chunk("Response");
4527 model.end_last_completion_stream();
4528
4529 send.await.unwrap();
4530
4531 assert_eq!(
4532 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4533 indoc! {r#"
4534 ## User
4535
4536 Prompt
4537
4538 ## Assistant
4539
4540 spawning subagent
4541
4542 **Tool Call: label**
4543 Status: Completed
4544
4545 subagent task response 1
4546
4547 subagent task response 2
4548
4549 ## Assistant
4550
4551 Response
4552
4553 "#},
4554 );
4555}
4556
4557#[gpui::test]
4558async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4559 init_test(cx);
4560 cx.update(|cx| {
4561 LanguageModelRegistry::test(cx);
4562 });
4563 cx.update(|cx| {
4564 cx.update_flags(true, vec!["subagents".to_string()]);
4565 });
4566
4567 let fs = FakeFs::new(cx.executor());
4568 fs.insert_tree(
4569 "/",
4570 json!({
4571 "a": {
4572 "b.md": "Lorem"
4573 }
4574 }),
4575 )
4576 .await;
4577 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4578 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4579 let agent = NativeAgent::new(
4580 project.clone(),
4581 thread_store.clone(),
4582 Templates::new(),
4583 None,
4584 fs.clone(),
4585 &mut cx.to_async(),
4586 )
4587 .await
4588 .unwrap();
4589 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4590
4591 let acp_thread = cx
4592 .update(|cx| {
4593 connection
4594 .clone()
4595 .new_session(project.clone(), Path::new(""), cx)
4596 })
4597 .await
4598 .unwrap();
4599 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4600 let thread = agent.read_with(cx, |agent, _| {
4601 agent.sessions.get(&session_id).unwrap().thread.clone()
4602 });
4603 let model = Arc::new(FakeLanguageModel::default());
4604
4605 // Ensure empty threads are not saved, even if they get mutated.
4606 thread.update(cx, |thread, cx| {
4607 thread.set_model(model.clone(), cx);
4608 });
4609 cx.run_until_parked();
4610
4611 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4612 cx.run_until_parked();
4613 model.send_last_completion_stream_text_chunk("spawning subagent");
4614 let subagent_tool_input = SpawnAgentToolInput {
4615 label: "label".to_string(),
4616 message: "subagent task prompt".to_string(),
4617 session_id: None,
4618 };
4619 let subagent_tool_use = LanguageModelToolUse {
4620 id: "subagent_1".into(),
4621 name: SpawnAgentTool::NAME.into(),
4622 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4623 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4624 is_input_complete: true,
4625 thought_signature: None,
4626 };
4627 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4628 subagent_tool_use,
4629 ));
4630 model.end_last_completion_stream();
4631
4632 cx.run_until_parked();
4633
4634 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4635 thread
4636 .running_subagent_ids(cx)
4637 .get(0)
4638 .expect("subagent thread should be running")
4639 .clone()
4640 });
4641 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4642 agent
4643 .sessions
4644 .get(&subagent_session_id)
4645 .expect("subagent session should exist")
4646 .acp_thread
4647 .clone()
4648 });
4649
4650 // model.send_last_completion_stream_text_chunk("subagent task response");
4651 // model.end_last_completion_stream();
4652
4653 // cx.run_until_parked();
4654
4655 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4656
4657 cx.run_until_parked();
4658
4659 send.await.unwrap();
4660
4661 acp_thread.read_with(cx, |thread, cx| {
4662 assert_eq!(thread.status(), ThreadStatus::Idle);
4663 assert_eq!(
4664 thread.to_markdown(cx),
4665 indoc! {"
4666 ## User
4667
4668 Prompt
4669
4670 ## Assistant
4671
4672 spawning subagent
4673
4674 **Tool Call: label**
4675 Status: Canceled
4676
4677 "}
4678 );
4679 });
4680 subagent_acp_thread.read_with(cx, |thread, cx| {
4681 assert_eq!(thread.status(), ThreadStatus::Idle);
4682 assert_eq!(
4683 thread.to_markdown(cx),
4684 indoc! {"
4685 ## User
4686
4687 subagent task prompt
4688
4689 "}
4690 );
4691 });
4692}
4693
4694#[gpui::test]
4695async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4696 init_test(cx);
4697 cx.update(|cx| {
4698 LanguageModelRegistry::test(cx);
4699 });
4700 cx.update(|cx| {
4701 cx.update_flags(true, vec!["subagents".to_string()]);
4702 });
4703
4704 let fs = FakeFs::new(cx.executor());
4705 fs.insert_tree(
4706 "/",
4707 json!({
4708 "a": {
4709 "b.md": "Lorem"
4710 }
4711 }),
4712 )
4713 .await;
4714 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4715 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4716 let agent = NativeAgent::new(
4717 project.clone(),
4718 thread_store.clone(),
4719 Templates::new(),
4720 None,
4721 fs.clone(),
4722 &mut cx.to_async(),
4723 )
4724 .await
4725 .unwrap();
4726 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4727
4728 let acp_thread = cx
4729 .update(|cx| {
4730 connection
4731 .clone()
4732 .new_session(project.clone(), Path::new(""), cx)
4733 })
4734 .await
4735 .unwrap();
4736 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4737 let thread = agent.read_with(cx, |agent, _| {
4738 agent.sessions.get(&session_id).unwrap().thread.clone()
4739 });
4740 let model = Arc::new(FakeLanguageModel::default());
4741
4742 thread.update(cx, |thread, cx| {
4743 thread.set_model(model.clone(), cx);
4744 });
4745 cx.run_until_parked();
4746
4747 // === First turn: create subagent ===
4748 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4749 cx.run_until_parked();
4750 model.send_last_completion_stream_text_chunk("spawning subagent");
4751 let subagent_tool_input = SpawnAgentToolInput {
4752 label: "initial task".to_string(),
4753 message: "do the first task".to_string(),
4754 session_id: None,
4755 };
4756 let subagent_tool_use = LanguageModelToolUse {
4757 id: "subagent_1".into(),
4758 name: SpawnAgentTool::NAME.into(),
4759 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4760 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4761 is_input_complete: true,
4762 thought_signature: None,
4763 };
4764 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4765 subagent_tool_use,
4766 ));
4767 model.end_last_completion_stream();
4768
4769 cx.run_until_parked();
4770
4771 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4772 thread
4773 .running_subagent_ids(cx)
4774 .get(0)
4775 .expect("subagent thread should be running")
4776 .clone()
4777 });
4778
4779 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4780 agent
4781 .sessions
4782 .get(&subagent_session_id)
4783 .expect("subagent session should exist")
4784 .acp_thread
4785 .clone()
4786 });
4787
4788 // Subagent responds
4789 model.send_last_completion_stream_text_chunk("first task response");
4790 model.end_last_completion_stream();
4791
4792 cx.run_until_parked();
4793
4794 // Parent model responds to complete first turn
4795 model.send_last_completion_stream_text_chunk("First response");
4796 model.end_last_completion_stream();
4797
4798 send.await.unwrap();
4799
4800 // Verify subagent is no longer running
4801 thread.read_with(cx, |thread, cx| {
4802 assert!(
4803 thread.running_subagent_ids(cx).is_empty(),
4804 "subagent should not be running after completion"
4805 );
4806 });
4807
4808 // === Second turn: resume subagent with session_id ===
4809 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4810 cx.run_until_parked();
4811 model.send_last_completion_stream_text_chunk("resuming subagent");
4812 let resume_tool_input = SpawnAgentToolInput {
4813 label: "follow-up task".to_string(),
4814 message: "do the follow-up task".to_string(),
4815 session_id: Some(subagent_session_id.clone()),
4816 };
4817 let resume_tool_use = LanguageModelToolUse {
4818 id: "subagent_2".into(),
4819 name: SpawnAgentTool::NAME.into(),
4820 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4821 input: serde_json::to_value(&resume_tool_input).unwrap(),
4822 is_input_complete: true,
4823 thought_signature: None,
4824 };
4825 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4826 model.end_last_completion_stream();
4827
4828 cx.run_until_parked();
4829
4830 // Subagent should be running again with the same session
4831 thread.read_with(cx, |thread, cx| {
4832 let running = thread.running_subagent_ids(cx);
4833 assert_eq!(running.len(), 1, "subagent should be running");
4834 assert_eq!(running[0], subagent_session_id, "should be same session");
4835 });
4836
4837 // Subagent responds to follow-up
4838 model.send_last_completion_stream_text_chunk("follow-up task response");
4839 model.end_last_completion_stream();
4840
4841 cx.run_until_parked();
4842
4843 // Parent model responds to complete second turn
4844 model.send_last_completion_stream_text_chunk("Second response");
4845 model.end_last_completion_stream();
4846
4847 send2.await.unwrap();
4848
4849 // Verify subagent is no longer running
4850 thread.read_with(cx, |thread, cx| {
4851 assert!(
4852 thread.running_subagent_ids(cx).is_empty(),
4853 "subagent should not be running after resume completion"
4854 );
4855 });
4856
4857 // Verify the subagent's acp thread has both conversation turns
4858 assert_eq!(
4859 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4860 indoc! {"
4861 ## User
4862
4863 do the first task
4864
4865 ## Assistant
4866
4867 first task response
4868
4869 ## User
4870
4871 do the follow-up task
4872
4873 ## Assistant
4874
4875 follow-up task response
4876
4877 "}
4878 );
4879}
4880
4881#[gpui::test]
4882async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4883 init_test(cx);
4884
4885 cx.update(|cx| {
4886 cx.update_flags(true, vec!["subagents".to_string()]);
4887 });
4888
4889 let fs = FakeFs::new(cx.executor());
4890 fs.insert_tree(path!("/test"), json!({})).await;
4891 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4892 let project_context = cx.new(|_cx| ProjectContext::default());
4893 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4894 let context_server_registry =
4895 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4896 let model = Arc::new(FakeLanguageModel::default());
4897
4898 let environment = Rc::new(cx.update(|cx| {
4899 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4900 }));
4901
4902 let thread = cx.new(|cx| {
4903 let mut thread = Thread::new(
4904 project.clone(),
4905 project_context,
4906 context_server_registry,
4907 Templates::new(),
4908 Some(model),
4909 cx,
4910 );
4911 thread.add_default_tools(environment, cx);
4912 thread
4913 });
4914
4915 thread.read_with(cx, |thread, _| {
4916 assert!(
4917 thread.has_registered_tool(SpawnAgentTool::NAME),
4918 "subagent tool should be present when feature flag is enabled"
4919 );
4920 });
4921}
4922
4923#[gpui::test]
4924async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4925 init_test(cx);
4926
4927 cx.update(|cx| {
4928 cx.update_flags(true, vec!["subagents".to_string()]);
4929 });
4930
4931 let fs = FakeFs::new(cx.executor());
4932 fs.insert_tree(path!("/test"), json!({})).await;
4933 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4934 let project_context = cx.new(|_cx| ProjectContext::default());
4935 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4936 let context_server_registry =
4937 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4938 let model = Arc::new(FakeLanguageModel::default());
4939
4940 let parent_thread = cx.new(|cx| {
4941 Thread::new(
4942 project.clone(),
4943 project_context,
4944 context_server_registry,
4945 Templates::new(),
4946 Some(model.clone()),
4947 cx,
4948 )
4949 });
4950
4951 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4952 subagent_thread.read_with(cx, |subagent_thread, cx| {
4953 assert!(subagent_thread.is_subagent());
4954 assert_eq!(subagent_thread.depth(), 1);
4955 assert_eq!(
4956 subagent_thread.model().map(|model| model.id()),
4957 Some(model.id())
4958 );
4959 assert_eq!(
4960 subagent_thread.parent_thread_id(),
4961 Some(parent_thread.read(cx).id().clone())
4962 );
4963 });
4964}
4965
4966#[gpui::test]
4967async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4968 init_test(cx);
4969
4970 cx.update(|cx| {
4971 cx.update_flags(true, vec!["subagents".to_string()]);
4972 });
4973
4974 let fs = FakeFs::new(cx.executor());
4975 fs.insert_tree(path!("/test"), json!({})).await;
4976 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4977 let project_context = cx.new(|_cx| ProjectContext::default());
4978 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4979 let context_server_registry =
4980 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4981 let model = Arc::new(FakeLanguageModel::default());
4982 let environment = Rc::new(cx.update(|cx| {
4983 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4984 }));
4985
4986 let deep_parent_thread = cx.new(|cx| {
4987 let mut thread = Thread::new(
4988 project.clone(),
4989 project_context,
4990 context_server_registry,
4991 Templates::new(),
4992 Some(model.clone()),
4993 cx,
4994 );
4995 thread.set_subagent_context(SubagentContext {
4996 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4997 depth: MAX_SUBAGENT_DEPTH - 1,
4998 });
4999 thread
5000 });
5001 let deep_subagent_thread = cx.new(|cx| {
5002 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
5003 thread.add_default_tools(environment, cx);
5004 thread
5005 });
5006
5007 deep_subagent_thread.read_with(cx, |thread, _| {
5008 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
5009 assert!(
5010 !thread.has_registered_tool(SpawnAgentTool::NAME),
5011 "subagent tool should not be present at max depth"
5012 );
5013 });
5014}
5015
5016#[gpui::test]
5017async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
5018 init_test(cx);
5019
5020 cx.update(|cx| {
5021 cx.update_flags(true, vec!["subagents".to_string()]);
5022 });
5023
5024 let fs = FakeFs::new(cx.executor());
5025 fs.insert_tree(path!("/test"), json!({})).await;
5026 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5027 let project_context = cx.new(|_cx| ProjectContext::default());
5028 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5029 let context_server_registry =
5030 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5031 let model = Arc::new(FakeLanguageModel::default());
5032
5033 let parent = cx.new(|cx| {
5034 Thread::new(
5035 project.clone(),
5036 project_context.clone(),
5037 context_server_registry.clone(),
5038 Templates::new(),
5039 Some(model.clone()),
5040 cx,
5041 )
5042 });
5043
5044 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
5045
5046 parent.update(cx, |thread, _cx| {
5047 thread.register_running_subagent(subagent.downgrade());
5048 });
5049
5050 subagent
5051 .update(cx, |thread, cx| {
5052 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
5053 })
5054 .unwrap();
5055 cx.run_until_parked();
5056
5057 subagent.read_with(cx, |thread, _| {
5058 assert!(!thread.is_turn_complete(), "subagent should be running");
5059 });
5060
5061 parent.update(cx, |thread, cx| {
5062 thread.cancel(cx).detach();
5063 });
5064
5065 subagent.read_with(cx, |thread, _| {
5066 assert!(
5067 thread.is_turn_complete(),
5068 "subagent should be cancelled when parent cancels"
5069 );
5070 });
5071}
5072
5073#[gpui::test]
5074async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
5075 init_test(cx);
5076 cx.update(|cx| {
5077 LanguageModelRegistry::test(cx);
5078 });
5079 cx.update(|cx| {
5080 cx.update_flags(true, vec!["subagents".to_string()]);
5081 });
5082
5083 let fs = FakeFs::new(cx.executor());
5084 fs.insert_tree(
5085 "/",
5086 json!({
5087 "a": {
5088 "b.md": "Lorem"
5089 }
5090 }),
5091 )
5092 .await;
5093 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5094 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5095 let agent = NativeAgent::new(
5096 project.clone(),
5097 thread_store.clone(),
5098 Templates::new(),
5099 None,
5100 fs.clone(),
5101 &mut cx.to_async(),
5102 )
5103 .await
5104 .unwrap();
5105 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5106
5107 let acp_thread = cx
5108 .update(|cx| {
5109 connection
5110 .clone()
5111 .new_session(project.clone(), Path::new(""), cx)
5112 })
5113 .await
5114 .unwrap();
5115 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5116 let thread = agent.read_with(cx, |agent, _| {
5117 agent.sessions.get(&session_id).unwrap().thread.clone()
5118 });
5119 let model = Arc::new(FakeLanguageModel::default());
5120
5121 thread.update(cx, |thread, cx| {
5122 thread.set_model(model.clone(), cx);
5123 });
5124 cx.run_until_parked();
5125
5126 // Start the parent turn
5127 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5128 cx.run_until_parked();
5129 model.send_last_completion_stream_text_chunk("spawning subagent");
5130 let subagent_tool_input = SpawnAgentToolInput {
5131 label: "label".to_string(),
5132 message: "subagent task prompt".to_string(),
5133 session_id: None,
5134 };
5135 let subagent_tool_use = LanguageModelToolUse {
5136 id: "subagent_1".into(),
5137 name: SpawnAgentTool::NAME.into(),
5138 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5139 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5140 is_input_complete: true,
5141 thought_signature: None,
5142 };
5143 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5144 subagent_tool_use,
5145 ));
5146 model.end_last_completion_stream();
5147
5148 cx.run_until_parked();
5149
5150 // Verify subagent is running
5151 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5152 thread
5153 .running_subagent_ids(cx)
5154 .get(0)
5155 .expect("subagent thread should be running")
5156 .clone()
5157 });
5158
5159 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
5160 model.send_last_completion_stream_text_chunk("partial work");
5161 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5162 TokenUsage {
5163 input_tokens: 850_000,
5164 output_tokens: 0,
5165 cache_creation_input_tokens: 0,
5166 cache_read_input_tokens: 0,
5167 },
5168 ));
5169
5170 cx.run_until_parked();
5171
5172 // The subagent should no longer be running
5173 thread.read_with(cx, |thread, cx| {
5174 assert!(
5175 thread.running_subagent_ids(cx).is_empty(),
5176 "subagent should be stopped after context window warning"
5177 );
5178 });
5179
5180 // The parent model should get a new completion request to respond to the tool error
5181 model.send_last_completion_stream_text_chunk("Response after warning");
5182 model.end_last_completion_stream();
5183
5184 send.await.unwrap();
5185
5186 // Verify the parent thread shows the warning error in the tool call
5187 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5188 assert!(
5189 markdown.contains("nearing the end of its context window"),
5190 "tool output should contain context window warning message, got:\n{markdown}"
5191 );
5192 assert!(
5193 markdown.contains("Status: Failed"),
5194 "tool call should have Failed status, got:\n{markdown}"
5195 );
5196
5197 // Verify the subagent session still exists (can be resumed)
5198 agent.read_with(cx, |agent, _cx| {
5199 assert!(
5200 agent.sessions.contains_key(&subagent_session_id),
5201 "subagent session should still exist for potential resume"
5202 );
5203 });
5204}
5205
5206#[gpui::test]
5207async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
5208 init_test(cx);
5209 cx.update(|cx| {
5210 LanguageModelRegistry::test(cx);
5211 });
5212 cx.update(|cx| {
5213 cx.update_flags(true, vec!["subagents".to_string()]);
5214 });
5215
5216 let fs = FakeFs::new(cx.executor());
5217 fs.insert_tree(
5218 "/",
5219 json!({
5220 "a": {
5221 "b.md": "Lorem"
5222 }
5223 }),
5224 )
5225 .await;
5226 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5227 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5228 let agent = NativeAgent::new(
5229 project.clone(),
5230 thread_store.clone(),
5231 Templates::new(),
5232 None,
5233 fs.clone(),
5234 &mut cx.to_async(),
5235 )
5236 .await
5237 .unwrap();
5238 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5239
5240 let acp_thread = cx
5241 .update(|cx| {
5242 connection
5243 .clone()
5244 .new_session(project.clone(), Path::new(""), cx)
5245 })
5246 .await
5247 .unwrap();
5248 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5249 let thread = agent.read_with(cx, |agent, _| {
5250 agent.sessions.get(&session_id).unwrap().thread.clone()
5251 });
5252 let model = Arc::new(FakeLanguageModel::default());
5253
5254 thread.update(cx, |thread, cx| {
5255 thread.set_model(model.clone(), cx);
5256 });
5257 cx.run_until_parked();
5258
5259 // === First turn: create subagent, trigger context window warning ===
5260 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
5261 cx.run_until_parked();
5262 model.send_last_completion_stream_text_chunk("spawning subagent");
5263 let subagent_tool_input = SpawnAgentToolInput {
5264 label: "initial task".to_string(),
5265 message: "do the first task".to_string(),
5266 session_id: None,
5267 };
5268 let subagent_tool_use = LanguageModelToolUse {
5269 id: "subagent_1".into(),
5270 name: SpawnAgentTool::NAME.into(),
5271 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5272 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5273 is_input_complete: true,
5274 thought_signature: None,
5275 };
5276 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5277 subagent_tool_use,
5278 ));
5279 model.end_last_completion_stream();
5280
5281 cx.run_until_parked();
5282
5283 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5284 thread
5285 .running_subagent_ids(cx)
5286 .get(0)
5287 .expect("subagent thread should be running")
5288 .clone()
5289 });
5290
5291 // Subagent sends a usage update that crosses the warning threshold.
5292 // This triggers Normal→Warning, stopping the subagent.
5293 model.send_last_completion_stream_text_chunk("partial work");
5294 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5295 TokenUsage {
5296 input_tokens: 850_000,
5297 output_tokens: 0,
5298 cache_creation_input_tokens: 0,
5299 cache_read_input_tokens: 0,
5300 },
5301 ));
5302
5303 cx.run_until_parked();
5304
5305 // Verify the first turn was stopped with a context window warning
5306 thread.read_with(cx, |thread, cx| {
5307 assert!(
5308 thread.running_subagent_ids(cx).is_empty(),
5309 "subagent should be stopped after context window warning"
5310 );
5311 });
5312
5313 // Parent model responds to complete first turn
5314 model.send_last_completion_stream_text_chunk("First response");
5315 model.end_last_completion_stream();
5316
5317 send.await.unwrap();
5318
5319 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5320 assert!(
5321 markdown.contains("nearing the end of its context window"),
5322 "first turn should have context window warning, got:\n{markdown}"
5323 );
5324
5325 // === Second turn: resume the same subagent (now at Warning level) ===
5326 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
5327 cx.run_until_parked();
5328 model.send_last_completion_stream_text_chunk("resuming subagent");
5329 let resume_tool_input = SpawnAgentToolInput {
5330 label: "follow-up task".to_string(),
5331 message: "do the follow-up task".to_string(),
5332 session_id: Some(subagent_session_id.clone()),
5333 };
5334 let resume_tool_use = LanguageModelToolUse {
5335 id: "subagent_2".into(),
5336 name: SpawnAgentTool::NAME.into(),
5337 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
5338 input: serde_json::to_value(&resume_tool_input).unwrap(),
5339 is_input_complete: true,
5340 thought_signature: None,
5341 };
5342 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
5343 model.end_last_completion_stream();
5344
5345 cx.run_until_parked();
5346
5347 // Subagent responds with tokens still at warning level (no worse).
5348 // Since ratio_before_prompt was already Warning, this should NOT
5349 // trigger the context window warning again.
5350 model.send_last_completion_stream_text_chunk("follow-up task response");
5351 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5352 TokenUsage {
5353 input_tokens: 870_000,
5354 output_tokens: 0,
5355 cache_creation_input_tokens: 0,
5356 cache_read_input_tokens: 0,
5357 },
5358 ));
5359 model.end_last_completion_stream();
5360
5361 cx.run_until_parked();
5362
5363 // Parent model responds to complete second turn
5364 model.send_last_completion_stream_text_chunk("Second response");
5365 model.end_last_completion_stream();
5366
5367 send2.await.unwrap();
5368
5369 // The resumed subagent should have completed normally since the ratio
5370 // didn't transition (it was Warning before and stayed at Warning)
5371 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5372 assert!(
5373 markdown.contains("follow-up task response"),
5374 "resumed subagent should complete normally when already at warning, got:\n{markdown}"
5375 );
5376 // The second tool call should NOT have a context window warning
5377 let second_tool_pos = markdown
5378 .find("follow-up task")
5379 .expect("should find follow-up tool call");
5380 let after_second_tool = &markdown[second_tool_pos..];
5381 assert!(
5382 !after_second_tool.contains("nearing the end of its context window"),
5383 "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
5384 );
5385}
5386
// Verifies that a non-retryable completion error raised while a subagent is
// running stops that subagent and is surfaced in the parent thread as a
// failed tool call, after which the parent turn can still finish normally.
#[gpui::test]
async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
    init_test(cx);
    cx.update(|cx| {
        LanguageModelRegistry::test(cx);
    });
    // Turn on the "subagents" feature flag (presumably required for the
    // spawn-agent tool to be usable — TODO confirm).
    cx.update(|cx| {
        cx.update_flags(true, vec!["subagents".to_string()]);
    });

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree(
        "/",
        json!({
            "a": {
                "b.md": "Lorem"
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
    let thread_store = cx.new(|cx| ThreadStore::new(cx));
    let agent = NativeAgent::new(
        project.clone(),
        thread_store.clone(),
        Templates::new(),
        None,
        fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = Rc::new(NativeAgentConnection(agent.clone()));

    // Open an ACP session and look up the native `Thread` backing it.
    let acp_thread = cx
        .update(|cx| {
            connection
                .clone()
                .new_session(project.clone(), Path::new(""), cx)
        })
        .await
        .unwrap();
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
    let thread = agent.read_with(cx, |agent, _| {
        agent.sessions.get(&session_id).unwrap().thread.clone()
    });
    // A single fake model is installed on the parent thread. The
    // `send_last_completion_*` calls below drive whichever completion request
    // is currently last: the parent's first, then the subagent's, then the
    // parent's follow-up.
    let model = Arc::new(FakeLanguageModel::default());

    thread.update(cx, |thread, cx| {
        thread.set_model(model.clone(), cx);
    });
    cx.run_until_parked();

    // Start the parent turn
    let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
    cx.run_until_parked();
    // The parent model answers by calling the spawn-agent tool with a fresh
    // subagent (`session_id: None`).
    model.send_last_completion_stream_text_chunk("spawning subagent");
    let subagent_tool_input = SpawnAgentToolInput {
        label: "label".to_string(),
        message: "subagent task prompt".to_string(),
        session_id: None,
    };
    let subagent_tool_use = LanguageModelToolUse {
        id: "subagent_1".into(),
        name: SpawnAgentTool::NAME.into(),
        raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
        input: serde_json::to_value(&subagent_tool_input).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        subagent_tool_use,
    ));
    model.end_last_completion_stream();

    cx.run_until_parked();

    // Verify subagent is running
    thread.read_with(cx, |thread, cx| {
        assert!(
            !thread.running_subagent_ids(cx).is_empty(),
            "subagent should be running"
        );
    });

    // The subagent's model returns a non-retryable error
    model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
        tokens: None,
    });

    cx.run_until_parked();

    // The subagent should no longer be running
    thread.read_with(cx, |thread, cx| {
        assert!(
            thread.running_subagent_ids(cx).is_empty(),
            "subagent should not be running after error"
        );
    });

    // The parent model should get a new completion request to respond to the tool error
    model.send_last_completion_stream_text_chunk("Response after error");
    model.end_last_completion_stream();

    // The parent turn must still resolve successfully despite the subagent failure.
    send.await.unwrap();

    // Verify the parent thread shows the error in the tool call
    let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
    assert!(
        markdown.contains("Status: Failed"),
        "tool call should have Failed status after model error, got:\n{markdown}"
    );
}
5500
5501#[gpui::test]
5502async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5503 init_test(cx);
5504
5505 let fs = FakeFs::new(cx.executor());
5506 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5507 .await;
5508 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5509
5510 cx.update(|cx| {
5511 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5512 settings.tool_permissions.tools.insert(
5513 EditFileTool::NAME.into(),
5514 agent_settings::ToolRules {
5515 default: Some(settings::ToolPermissionMode::Allow),
5516 always_allow: vec![],
5517 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5518 always_confirm: vec![],
5519 invalid_patterns: vec![],
5520 },
5521 );
5522 agent_settings::AgentSettings::override_global(settings, cx);
5523 });
5524
5525 let context_server_registry =
5526 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5527 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5528 let templates = crate::Templates::new();
5529 let thread = cx.new(|cx| {
5530 crate::Thread::new(
5531 project.clone(),
5532 cx.new(|_cx| prompt_store::ProjectContext::default()),
5533 context_server_registry,
5534 templates.clone(),
5535 None,
5536 cx,
5537 )
5538 });
5539
5540 #[allow(clippy::arc_with_non_send_sync)]
5541 let tool = Arc::new(crate::EditFileTool::new(
5542 project.clone(),
5543 thread.downgrade(),
5544 language_registry,
5545 templates,
5546 ));
5547 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5548
5549 let task = cx.update(|cx| {
5550 tool.run(
5551 ToolInput::resolved(crate::EditFileToolInput {
5552 display_description: "Edit sensitive file".to_string(),
5553 path: "root/sensitive_config.txt".into(),
5554 mode: crate::EditFileMode::Edit,
5555 }),
5556 event_stream,
5557 cx,
5558 )
5559 });
5560
5561 let result = task.await;
5562 assert!(result.is_err(), "expected edit to be blocked");
5563 assert!(
5564 result.unwrap_err().to_string().contains("blocked"),
5565 "error should mention the edit was blocked"
5566 );
5567}
5568
5569#[gpui::test]
5570async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5571 init_test(cx);
5572
5573 let fs = FakeFs::new(cx.executor());
5574 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5575 .await;
5576 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5577
5578 cx.update(|cx| {
5579 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5580 settings.tool_permissions.tools.insert(
5581 DeletePathTool::NAME.into(),
5582 agent_settings::ToolRules {
5583 default: Some(settings::ToolPermissionMode::Allow),
5584 always_allow: vec![],
5585 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5586 always_confirm: vec![],
5587 invalid_patterns: vec![],
5588 },
5589 );
5590 agent_settings::AgentSettings::override_global(settings, cx);
5591 });
5592
5593 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5594
5595 #[allow(clippy::arc_with_non_send_sync)]
5596 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5597 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5598
5599 let task = cx.update(|cx| {
5600 tool.run(
5601 ToolInput::resolved(crate::DeletePathToolInput {
5602 path: "root/important_data.txt".to_string(),
5603 }),
5604 event_stream,
5605 cx,
5606 )
5607 });
5608
5609 let result = task.await;
5610 assert!(result.is_err(), "expected deletion to be blocked");
5611 assert!(
5612 result.unwrap_err().contains("blocked"),
5613 "error should mention the deletion was blocked"
5614 );
5615}
5616
5617#[gpui::test]
5618async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5619 init_test(cx);
5620
5621 let fs = FakeFs::new(cx.executor());
5622 fs.insert_tree(
5623 "/root",
5624 json!({
5625 "safe.txt": "content",
5626 "protected": {}
5627 }),
5628 )
5629 .await;
5630 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5631
5632 cx.update(|cx| {
5633 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5634 settings.tool_permissions.tools.insert(
5635 MovePathTool::NAME.into(),
5636 agent_settings::ToolRules {
5637 default: Some(settings::ToolPermissionMode::Allow),
5638 always_allow: vec![],
5639 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5640 always_confirm: vec![],
5641 invalid_patterns: vec![],
5642 },
5643 );
5644 agent_settings::AgentSettings::override_global(settings, cx);
5645 });
5646
5647 #[allow(clippy::arc_with_non_send_sync)]
5648 let tool = Arc::new(crate::MovePathTool::new(project));
5649 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5650
5651 let task = cx.update(|cx| {
5652 tool.run(
5653 ToolInput::resolved(crate::MovePathToolInput {
5654 source_path: "root/safe.txt".to_string(),
5655 destination_path: "root/protected/safe.txt".to_string(),
5656 }),
5657 event_stream,
5658 cx,
5659 )
5660 });
5661
5662 let result = task.await;
5663 assert!(
5664 result.is_err(),
5665 "expected move to be blocked due to destination path"
5666 );
5667 assert!(
5668 result.unwrap_err().contains("blocked"),
5669 "error should mention the move was blocked"
5670 );
5671}
5672
5673#[gpui::test]
5674async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5675 init_test(cx);
5676
5677 let fs = FakeFs::new(cx.executor());
5678 fs.insert_tree(
5679 "/root",
5680 json!({
5681 "secret.txt": "secret content",
5682 "public": {}
5683 }),
5684 )
5685 .await;
5686 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5687
5688 cx.update(|cx| {
5689 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5690 settings.tool_permissions.tools.insert(
5691 MovePathTool::NAME.into(),
5692 agent_settings::ToolRules {
5693 default: Some(settings::ToolPermissionMode::Allow),
5694 always_allow: vec![],
5695 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5696 always_confirm: vec![],
5697 invalid_patterns: vec![],
5698 },
5699 );
5700 agent_settings::AgentSettings::override_global(settings, cx);
5701 });
5702
5703 #[allow(clippy::arc_with_non_send_sync)]
5704 let tool = Arc::new(crate::MovePathTool::new(project));
5705 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5706
5707 let task = cx.update(|cx| {
5708 tool.run(
5709 ToolInput::resolved(crate::MovePathToolInput {
5710 source_path: "root/secret.txt".to_string(),
5711 destination_path: "root/public/not_secret.txt".to_string(),
5712 }),
5713 event_stream,
5714 cx,
5715 )
5716 });
5717
5718 let result = task.await;
5719 assert!(
5720 result.is_err(),
5721 "expected move to be blocked due to source path"
5722 );
5723 assert!(
5724 result.unwrap_err().contains("blocked"),
5725 "error should mention the move was blocked"
5726 );
5727}
5728
5729#[gpui::test]
5730async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5731 init_test(cx);
5732
5733 let fs = FakeFs::new(cx.executor());
5734 fs.insert_tree(
5735 "/root",
5736 json!({
5737 "confidential.txt": "confidential data",
5738 "dest": {}
5739 }),
5740 )
5741 .await;
5742 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5743
5744 cx.update(|cx| {
5745 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5746 settings.tool_permissions.tools.insert(
5747 CopyPathTool::NAME.into(),
5748 agent_settings::ToolRules {
5749 default: Some(settings::ToolPermissionMode::Allow),
5750 always_allow: vec![],
5751 always_deny: vec![
5752 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5753 ],
5754 always_confirm: vec![],
5755 invalid_patterns: vec![],
5756 },
5757 );
5758 agent_settings::AgentSettings::override_global(settings, cx);
5759 });
5760
5761 #[allow(clippy::arc_with_non_send_sync)]
5762 let tool = Arc::new(crate::CopyPathTool::new(project));
5763 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5764
5765 let task = cx.update(|cx| {
5766 tool.run(
5767 ToolInput::resolved(crate::CopyPathToolInput {
5768 source_path: "root/confidential.txt".to_string(),
5769 destination_path: "root/dest/copy.txt".to_string(),
5770 }),
5771 event_stream,
5772 cx,
5773 )
5774 });
5775
5776 let result = task.await;
5777 assert!(result.is_err(), "expected copy to be blocked");
5778 assert!(
5779 result.unwrap_err().contains("blocked"),
5780 "error should mention the copy was blocked"
5781 );
5782}
5783
5784#[gpui::test]
5785async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5786 init_test(cx);
5787
5788 let fs = FakeFs::new(cx.executor());
5789 fs.insert_tree(
5790 "/root",
5791 json!({
5792 "normal.txt": "normal content",
5793 "readonly": {
5794 "config.txt": "readonly content"
5795 }
5796 }),
5797 )
5798 .await;
5799 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5800
5801 cx.update(|cx| {
5802 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5803 settings.tool_permissions.tools.insert(
5804 SaveFileTool::NAME.into(),
5805 agent_settings::ToolRules {
5806 default: Some(settings::ToolPermissionMode::Allow),
5807 always_allow: vec![],
5808 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5809 always_confirm: vec![],
5810 invalid_patterns: vec![],
5811 },
5812 );
5813 agent_settings::AgentSettings::override_global(settings, cx);
5814 });
5815
5816 #[allow(clippy::arc_with_non_send_sync)]
5817 let tool = Arc::new(crate::SaveFileTool::new(project));
5818 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5819
5820 let task = cx.update(|cx| {
5821 tool.run(
5822 ToolInput::resolved(crate::SaveFileToolInput {
5823 paths: vec![
5824 std::path::PathBuf::from("root/normal.txt"),
5825 std::path::PathBuf::from("root/readonly/config.txt"),
5826 ],
5827 }),
5828 event_stream,
5829 cx,
5830 )
5831 });
5832
5833 let result = task.await;
5834 assert!(
5835 result.is_err(),
5836 "expected save to be blocked due to denied path"
5837 );
5838 assert!(
5839 result.unwrap_err().contains("blocked"),
5840 "error should mention the save was blocked"
5841 );
5842}
5843
5844#[gpui::test]
5845async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5846 init_test(cx);
5847
5848 let fs = FakeFs::new(cx.executor());
5849 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5850 .await;
5851 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5852
5853 cx.update(|cx| {
5854 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5855 settings.tool_permissions.tools.insert(
5856 SaveFileTool::NAME.into(),
5857 agent_settings::ToolRules {
5858 default: Some(settings::ToolPermissionMode::Allow),
5859 always_allow: vec![],
5860 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5861 always_confirm: vec![],
5862 invalid_patterns: vec![],
5863 },
5864 );
5865 agent_settings::AgentSettings::override_global(settings, cx);
5866 });
5867
5868 #[allow(clippy::arc_with_non_send_sync)]
5869 let tool = Arc::new(crate::SaveFileTool::new(project));
5870 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5871
5872 let task = cx.update(|cx| {
5873 tool.run(
5874 ToolInput::resolved(crate::SaveFileToolInput {
5875 paths: vec![std::path::PathBuf::from("root/config.secret")],
5876 }),
5877 event_stream,
5878 cx,
5879 )
5880 });
5881
5882 let result = task.await;
5883 assert!(result.is_err(), "expected save to be blocked");
5884 assert!(
5885 result.unwrap_err().contains("blocked"),
5886 "error should mention the save was blocked"
5887 );
5888}
5889
5890#[gpui::test]
5891async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5892 init_test(cx);
5893
5894 cx.update(|cx| {
5895 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5896 settings.tool_permissions.tools.insert(
5897 WebSearchTool::NAME.into(),
5898 agent_settings::ToolRules {
5899 default: Some(settings::ToolPermissionMode::Allow),
5900 always_allow: vec![],
5901 always_deny: vec![
5902 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5903 ],
5904 always_confirm: vec![],
5905 invalid_patterns: vec![],
5906 },
5907 );
5908 agent_settings::AgentSettings::override_global(settings, cx);
5909 });
5910
5911 #[allow(clippy::arc_with_non_send_sync)]
5912 let tool = Arc::new(crate::WebSearchTool);
5913 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5914
5915 let input: crate::WebSearchToolInput =
5916 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5917
5918 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
5919
5920 let result = task.await;
5921 assert!(result.is_err(), "expected search to be blocked");
5922 match result.unwrap_err() {
5923 crate::WebSearchToolOutput::Error { error } => {
5924 assert!(
5925 error.contains("blocked"),
5926 "error should mention the search was blocked"
5927 );
5928 }
5929 other => panic!("expected Error variant, got: {other:?}"),
5930 }
5931}
5932
5933#[gpui::test]
5934async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5935 init_test(cx);
5936
5937 let fs = FakeFs::new(cx.executor());
5938 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5939 .await;
5940 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5941
5942 cx.update(|cx| {
5943 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5944 settings.tool_permissions.tools.insert(
5945 EditFileTool::NAME.into(),
5946 agent_settings::ToolRules {
5947 default: Some(settings::ToolPermissionMode::Confirm),
5948 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5949 always_deny: vec![],
5950 always_confirm: vec![],
5951 invalid_patterns: vec![],
5952 },
5953 );
5954 agent_settings::AgentSettings::override_global(settings, cx);
5955 });
5956
5957 let context_server_registry =
5958 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5959 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5960 let templates = crate::Templates::new();
5961 let thread = cx.new(|cx| {
5962 crate::Thread::new(
5963 project.clone(),
5964 cx.new(|_cx| prompt_store::ProjectContext::default()),
5965 context_server_registry,
5966 templates.clone(),
5967 None,
5968 cx,
5969 )
5970 });
5971
5972 #[allow(clippy::arc_with_non_send_sync)]
5973 let tool = Arc::new(crate::EditFileTool::new(
5974 project,
5975 thread.downgrade(),
5976 language_registry,
5977 templates,
5978 ));
5979 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5980
5981 let _task = cx.update(|cx| {
5982 tool.run(
5983 ToolInput::resolved(crate::EditFileToolInput {
5984 display_description: "Edit README".to_string(),
5985 path: "root/README.md".into(),
5986 mode: crate::EditFileMode::Edit,
5987 }),
5988 event_stream,
5989 cx,
5990 )
5991 });
5992
5993 cx.run_until_parked();
5994
5995 let event = rx.try_next();
5996 assert!(
5997 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5998 "expected no authorization request for allowed .md file"
5999 );
6000}
6001
6002#[gpui::test]
6003async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
6004 init_test(cx);
6005
6006 let fs = FakeFs::new(cx.executor());
6007 fs.insert_tree(
6008 "/root",
6009 json!({
6010 ".zed": {
6011 "settings.json": "{}"
6012 },
6013 "README.md": "# Hello"
6014 }),
6015 )
6016 .await;
6017 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
6018
6019 cx.update(|cx| {
6020 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6021 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
6022 agent_settings::AgentSettings::override_global(settings, cx);
6023 });
6024
6025 let context_server_registry =
6026 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
6027 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
6028 let templates = crate::Templates::new();
6029 let thread = cx.new(|cx| {
6030 crate::Thread::new(
6031 project.clone(),
6032 cx.new(|_cx| prompt_store::ProjectContext::default()),
6033 context_server_registry,
6034 templates.clone(),
6035 None,
6036 cx,
6037 )
6038 });
6039
6040 #[allow(clippy::arc_with_non_send_sync)]
6041 let tool = Arc::new(crate::EditFileTool::new(
6042 project,
6043 thread.downgrade(),
6044 language_registry,
6045 templates,
6046 ));
6047
6048 // Editing a file inside .zed/ should still prompt even with global default: allow,
6049 // because local settings paths are sensitive and require confirmation regardless.
6050 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6051 let _task = cx.update(|cx| {
6052 tool.run(
6053 ToolInput::resolved(crate::EditFileToolInput {
6054 display_description: "Edit local settings".to_string(),
6055 path: "root/.zed/settings.json".into(),
6056 mode: crate::EditFileMode::Edit,
6057 }),
6058 event_stream,
6059 cx,
6060 )
6061 });
6062
6063 let _update = rx.expect_update_fields().await;
6064 let _auth = rx.expect_authorization().await;
6065}
6066
6067#[gpui::test]
6068async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
6069 init_test(cx);
6070
6071 cx.update(|cx| {
6072 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6073 settings.tool_permissions.tools.insert(
6074 FetchTool::NAME.into(),
6075 agent_settings::ToolRules {
6076 default: Some(settings::ToolPermissionMode::Allow),
6077 always_allow: vec![],
6078 always_deny: vec![
6079 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
6080 ],
6081 always_confirm: vec![],
6082 invalid_patterns: vec![],
6083 },
6084 );
6085 agent_settings::AgentSettings::override_global(settings, cx);
6086 });
6087
6088 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6089
6090 #[allow(clippy::arc_with_non_send_sync)]
6091 let tool = Arc::new(crate::FetchTool::new(http_client));
6092 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6093
6094 let input: crate::FetchToolInput =
6095 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
6096
6097 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6098
6099 let result = task.await;
6100 assert!(result.is_err(), "expected fetch to be blocked");
6101 assert!(
6102 result.unwrap_err().contains("blocked"),
6103 "error should mention the fetch was blocked"
6104 );
6105}
6106
6107#[gpui::test]
6108async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6109 init_test(cx);
6110
6111 cx.update(|cx| {
6112 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6113 settings.tool_permissions.tools.insert(
6114 FetchTool::NAME.into(),
6115 agent_settings::ToolRules {
6116 default: Some(settings::ToolPermissionMode::Confirm),
6117 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
6118 always_deny: vec![],
6119 always_confirm: vec![],
6120 invalid_patterns: vec![],
6121 },
6122 );
6123 agent_settings::AgentSettings::override_global(settings, cx);
6124 });
6125
6126 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6127
6128 #[allow(clippy::arc_with_non_send_sync)]
6129 let tool = Arc::new(crate::FetchTool::new(http_client));
6130 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6131
6132 let input: crate::FetchToolInput =
6133 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
6134
6135 let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6136
6137 cx.run_until_parked();
6138
6139 let event = rx.try_next();
6140 assert!(
6141 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6142 "expected no authorization request for allowed docs.rs URL"
6143 );
6144}
6145
// Verifies that a message queued while the model is streaming a tool call
// makes the thread end the turn at the tool-call boundary. There is exactly one
// `EndTurn` stop event, and the queued flag is left set for the caller to consume.
#[gpui::test]
async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Add a tool so we can simulate tool calls
    thread.update(cx, |thread, _cx| {
        thread.add_tool(EchoTool);
    });

    // Start a turn by sending a message
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate the model making a tool call
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: r#"{"text": "hello"}"#.into(),
            input: json!({"text": "hello"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));

    // Signal that a message is queued before ending the stream. Order matters:
    // the flag must already be set when the thread reaches the boundary check.
    thread.update(cx, |thread, _cx| {
        thread.set_has_queued_message(true);
    });

    // Now end the stream - tool will run, and the boundary check should see the queue
    fake_model.end_last_completion_stream();

    // Collect all events until the turn stops
    let all_events = collect_events_until_stop(&mut events, cx).await;

    // Verify we received the tool call event
    let tool_call_ids: Vec<_> = all_events
        .iter()
        .filter_map(|e| match e {
            Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
            _ => None,
        })
        .collect();
    assert_eq!(
        tool_call_ids,
        vec!["tool_1"],
        "Should have received a tool call event for our echo tool"
    );

    // The turn should have stopped with EndTurn
    let stop_reasons = stop_events(all_events);
    assert_eq!(
        stop_reasons,
        vec![acp::StopReason::EndTurn],
        "Turn should have ended after tool completion due to queued message"
    );

    // Verify the queued message flag is still set
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.has_queued_message(),
            "Should still have queued message flag set"
        );
    });

    // Thread should be idle now
    thread.update(cx, |thread, _cx| {
        assert!(
            thread.is_turn_complete(),
            "Thread should not be running after turn ends"
        );
    });
}