1use super::*;
2use acp_thread::{
3 AgentConnection, AgentModelGroupName, AgentModelList, PermissionOptions, ThreadStatus,
4 UserMessageId,
5};
6use agent_client_protocol::{self as acp};
7use agent_settings::AgentProfileId;
8use anyhow::Result;
9use client::{Client, UserStore};
10use cloud_llm_client::CompletionIntent;
11use collections::IndexMap;
12use context_server::{ContextServer, ContextServerCommand, ContextServerId};
13use feature_flags::FeatureFlagAppExt as _;
14use fs::{FakeFs, Fs};
15use futures::{
16 FutureExt as _, StreamExt,
17 channel::{
18 mpsc::{self, UnboundedReceiver},
19 oneshot,
20 },
21 future::{Fuse, Shared},
22};
23use gpui::{
24 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
25 http_client::FakeHttpClient,
26};
27use indoc::indoc;
28use language_model::{
29 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
30 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
31 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
32 LanguageModelToolUse, MessageContent, Role, StopReason, TokenUsage,
33 fake_provider::FakeLanguageModel,
34};
35use pretty_assertions::assert_eq;
36use project::{
37 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
38};
39use prompt_store::ProjectContext;
40use reqwest_client::ReqwestClient;
41use schemars::JsonSchema;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use settings::{Settings, SettingsStore};
45use std::{
46 path::Path,
47 pin::Pin,
48 rc::Rc,
49 sync::{
50 Arc,
51 atomic::{AtomicBool, Ordering},
52 },
53 time::Duration,
54};
55use util::path;
56
57mod edit_file_thread_test;
58mod test_tools;
59use test_tools::*;
60
61fn init_test(cx: &mut TestAppContext) {
62 cx.update(|cx| {
63 let settings_store = SettingsStore::test(cx);
64 cx.set_global(settings_store);
65 });
66}
67
/// Test double for a terminal handle whose state the tests can inspect and drive.
struct FakeTerminalHandle {
    /// Set to `true` by `kill`; read back through `was_killed`.
    killed: Arc<AtomicBool>,
    /// Value reported by `was_stopped_by_user`; tests toggle it via `set_stopped_by_user`.
    stopped_by_user: Arc<AtomicBool>,
    /// One-shot sender consumed by `signal_exit`; `None` once it has been taken.
    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
    /// Shared task handed out by `wait_for_exit`.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    /// Canned response returned from `current_output`.
    output: acp::TerminalOutputResponse,
    /// Identifier returned from `id`.
    id: acp::TerminalId,
}
76
77impl FakeTerminalHandle {
78 fn new_never_exits(cx: &mut App) -> Self {
79 let killed = Arc::new(AtomicBool::new(false));
80 let stopped_by_user = Arc::new(AtomicBool::new(false));
81
82 let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
83
84 let wait_for_exit = cx
85 .spawn(async move |_cx| {
86 // Wait for the exit signal (sent when kill() is called)
87 let _ = exit_receiver.await;
88 acp::TerminalExitStatus::new()
89 })
90 .shared();
91
92 Self {
93 killed,
94 stopped_by_user,
95 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
96 wait_for_exit,
97 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
98 id: acp::TerminalId::new("fake_terminal".to_string()),
99 }
100 }
101
102 fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
103 let killed = Arc::new(AtomicBool::new(false));
104 let stopped_by_user = Arc::new(AtomicBool::new(false));
105 let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
106
107 let wait_for_exit = cx
108 .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
109 .shared();
110
111 Self {
112 killed,
113 stopped_by_user,
114 exit_sender: std::cell::RefCell::new(Some(exit_sender)),
115 wait_for_exit,
116 output: acp::TerminalOutputResponse::new("command output".to_string(), false),
117 id: acp::TerminalId::new("fake_terminal".to_string()),
118 }
119 }
120
121 fn was_killed(&self) -> bool {
122 self.killed.load(Ordering::SeqCst)
123 }
124
125 fn set_stopped_by_user(&self, stopped: bool) {
126 self.stopped_by_user.store(stopped, Ordering::SeqCst);
127 }
128
129 fn signal_exit(&self) {
130 if let Some(sender) = self.exit_sender.borrow_mut().take() {
131 let _ = sender.send(());
132 }
133 }
134}
135
136impl crate::TerminalHandle for FakeTerminalHandle {
137 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
138 Ok(self.id.clone())
139 }
140
141 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
142 Ok(self.output.clone())
143 }
144
145 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
146 Ok(self.wait_for_exit.clone())
147 }
148
149 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
150 self.killed.store(true, Ordering::SeqCst);
151 self.signal_exit();
152 Ok(())
153 }
154
155 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
156 Ok(self.stopped_by_user.load(Ordering::SeqCst))
157 }
158}
159
/// Test double for a subagent handle that answers every `send` with a preset task's output.
struct FakeSubagentHandle {
    /// Session id returned from `id`.
    session_id: acp::SessionId,
    /// Shared task whose output `send` returns, regardless of the message passed in.
    send_task: Shared<Task<String>>,
}
164
165impl SubagentHandle for FakeSubagentHandle {
166 fn id(&self) -> acp::SessionId {
167 self.session_id.clone()
168 }
169
170 fn num_entries(&self, _cx: &App) -> usize {
171 unimplemented!()
172 }
173
174 fn send(&self, _message: String, cx: &AsyncApp) -> Task<Result<String>> {
175 let task = self.send_task.clone();
176 cx.background_spawn(async move { Ok(task.await) })
177 }
178}
179
/// Test environment that hands out pre-built fake handles.
///
/// Both slots default to `None`; asking for a handle that was never configured panics.
#[derive(Default)]
struct FakeThreadEnvironment {
    /// Handle returned by `create_terminal`, if one was configured.
    terminal_handle: Option<Rc<FakeTerminalHandle>>,
    /// Handle returned by `create_subagent`, if one was configured.
    subagent_handle: Option<Rc<FakeSubagentHandle>>,
}
185
186impl FakeThreadEnvironment {
187 pub fn with_terminal(self, terminal_handle: FakeTerminalHandle) -> Self {
188 Self {
189 terminal_handle: Some(terminal_handle.into()),
190 ..self
191 }
192 }
193}
194
195impl crate::ThreadEnvironment for FakeThreadEnvironment {
196 fn create_terminal(
197 &self,
198 _command: String,
199 _cwd: Option<std::path::PathBuf>,
200 _output_byte_limit: Option<u64>,
201 _cx: &mut AsyncApp,
202 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
203 let handle = self
204 .terminal_handle
205 .clone()
206 .expect("Terminal handle not available on FakeThreadEnvironment");
207 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
208 }
209
210 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
211 Ok(self
212 .subagent_handle
213 .clone()
214 .expect("Subagent handle not available on FakeThreadEnvironment")
215 as Rc<dyn SubagentHandle>)
216 }
217}
218
/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
struct MultiTerminalEnvironment {
    /// Every handle created by `create_terminal`, in creation order.
    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
}
223
224impl MultiTerminalEnvironment {
225 fn new() -> Self {
226 Self {
227 handles: std::cell::RefCell::new(Vec::new()),
228 }
229 }
230
231 fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
232 self.handles.borrow().clone()
233 }
234}
235
236impl crate::ThreadEnvironment for MultiTerminalEnvironment {
237 fn create_terminal(
238 &self,
239 _command: String,
240 _cwd: Option<std::path::PathBuf>,
241 _output_byte_limit: Option<u64>,
242 cx: &mut AsyncApp,
243 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
244 let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
245 self.handles.borrow_mut().push(handle.clone());
246 Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
247 }
248
249 fn create_subagent(&self, _label: String, _cx: &mut App) -> Result<Rc<dyn SubagentHandle>> {
250 unimplemented!()
251 }
252}
253
254fn always_allow_tools(cx: &mut TestAppContext) {
255 cx.update(|cx| {
256 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
257 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
258 agent_settings::AgentSettings::override_global(settings, cx);
259 });
260}
261
262#[gpui::test]
263async fn test_echo(cx: &mut TestAppContext) {
264 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
265 let fake_model = model.as_fake();
266
267 let events = thread
268 .update(cx, |thread, cx| {
269 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
270 })
271 .unwrap();
272 cx.run_until_parked();
273 fake_model.send_last_completion_stream_text_chunk("Hello");
274 fake_model
275 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
276 fake_model.end_last_completion_stream();
277
278 let events = events.collect().await;
279 thread.update(cx, |thread, _cx| {
280 assert_eq!(
281 thread.last_received_or_pending_message().unwrap().role(),
282 Role::Assistant
283 );
284 assert_eq!(
285 thread
286 .last_received_or_pending_message()
287 .unwrap()
288 .to_markdown(),
289 "Hello\n"
290 )
291 });
292 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
293}
294
295#[gpui::test]
296async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
297 init_test(cx);
298 always_allow_tools(cx);
299
300 let fs = FakeFs::new(cx.executor());
301 let project = Project::test(fs, [], cx).await;
302
303 let environment = Rc::new(cx.update(|cx| {
304 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
305 }));
306 let handle = environment.terminal_handle.clone().unwrap();
307
308 #[allow(clippy::arc_with_non_send_sync)]
309 let tool = Arc::new(crate::TerminalTool::new(project, environment));
310 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
311
312 let task = cx.update(|cx| {
313 tool.run(
314 ToolInput::resolved(crate::TerminalToolInput {
315 command: "sleep 1000".to_string(),
316 cd: ".".to_string(),
317 timeout_ms: Some(5),
318 }),
319 event_stream,
320 cx,
321 )
322 });
323
324 let update = rx.expect_update_fields().await;
325 assert!(
326 update.content.iter().any(|blocks| {
327 blocks
328 .iter()
329 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
330 }),
331 "expected tool call update to include terminal content"
332 );
333
334 let mut task_future: Pin<Box<Fuse<Task<Result<String, String>>>>> = Box::pin(task.fuse());
335
336 let deadline = std::time::Instant::now() + Duration::from_millis(500);
337 loop {
338 if let Some(result) = task_future.as_mut().now_or_never() {
339 let result = result.expect("terminal tool task should complete");
340
341 assert!(
342 handle.was_killed(),
343 "expected terminal handle to be killed on timeout"
344 );
345 assert!(
346 result.contains("partial output"),
347 "expected result to include terminal output, got: {result}"
348 );
349 return;
350 }
351
352 if std::time::Instant::now() >= deadline {
353 panic!("timed out waiting for terminal tool task to complete");
354 }
355
356 cx.run_until_parked();
357 cx.background_executor.timer(Duration::from_millis(1)).await;
358 }
359}
360
361#[gpui::test]
362#[ignore]
363async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
364 init_test(cx);
365 always_allow_tools(cx);
366
367 let fs = FakeFs::new(cx.executor());
368 let project = Project::test(fs, [], cx).await;
369
370 let environment = Rc::new(cx.update(|cx| {
371 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
372 }));
373 let handle = environment.terminal_handle.clone().unwrap();
374
375 #[allow(clippy::arc_with_non_send_sync)]
376 let tool = Arc::new(crate::TerminalTool::new(project, environment));
377 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
378
379 let _task = cx.update(|cx| {
380 tool.run(
381 ToolInput::resolved(crate::TerminalToolInput {
382 command: "sleep 1000".to_string(),
383 cd: ".".to_string(),
384 timeout_ms: None,
385 }),
386 event_stream,
387 cx,
388 )
389 });
390
391 let update = rx.expect_update_fields().await;
392 assert!(
393 update.content.iter().any(|blocks| {
394 blocks
395 .iter()
396 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
397 }),
398 "expected tool call update to include terminal content"
399 );
400
401 cx.background_executor
402 .timer(Duration::from_millis(25))
403 .await;
404
405 assert!(
406 !handle.was_killed(),
407 "did not expect terminal handle to be killed without a timeout"
408 );
409}
410
411#[gpui::test]
412async fn test_thinking(cx: &mut TestAppContext) {
413 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
414 let fake_model = model.as_fake();
415
416 let events = thread
417 .update(cx, |thread, cx| {
418 thread.send(
419 UserMessageId::new(),
420 [indoc! {"
421 Testing:
422
423 Generate a thinking step where you just think the word 'Think',
424 and have your final answer be 'Hello'
425 "}],
426 cx,
427 )
428 })
429 .unwrap();
430 cx.run_until_parked();
431 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
432 text: "Think".to_string(),
433 signature: None,
434 });
435 fake_model.send_last_completion_stream_text_chunk("Hello");
436 fake_model
437 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
438 fake_model.end_last_completion_stream();
439
440 let events = events.collect().await;
441 thread.update(cx, |thread, _cx| {
442 assert_eq!(
443 thread.last_received_or_pending_message().unwrap().role(),
444 Role::Assistant
445 );
446 assert_eq!(
447 thread
448 .last_received_or_pending_message()
449 .unwrap()
450 .to_markdown(),
451 indoc! {"
452 <think>Think</think>
453 Hello
454 "}
455 )
456 });
457 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
458}
459
460#[gpui::test]
461async fn test_system_prompt(cx: &mut TestAppContext) {
462 let ThreadTest {
463 model,
464 thread,
465 project_context,
466 ..
467 } = setup(cx, TestModel::Fake).await;
468 let fake_model = model.as_fake();
469
470 project_context.update(cx, |project_context, _cx| {
471 project_context.shell = "test-shell".into()
472 });
473 thread.update(cx, |thread, _| thread.add_tool(EchoTool));
474 thread
475 .update(cx, |thread, cx| {
476 thread.send(UserMessageId::new(), ["abc"], cx)
477 })
478 .unwrap();
479 cx.run_until_parked();
480 let mut pending_completions = fake_model.pending_completions();
481 assert_eq!(
482 pending_completions.len(),
483 1,
484 "unexpected pending completions: {:?}",
485 pending_completions
486 );
487
488 let pending_completion = pending_completions.pop().unwrap();
489 assert_eq!(pending_completion.messages[0].role, Role::System);
490
491 let system_message = &pending_completion.messages[0];
492 let system_prompt = system_message.content[0].to_str().unwrap();
493 assert!(
494 system_prompt.contains("test-shell"),
495 "unexpected system message: {:?}",
496 system_message
497 );
498 assert!(
499 system_prompt.contains("## Fixing Diagnostics"),
500 "unexpected system message: {:?}",
501 system_message
502 );
503}
504
505#[gpui::test]
506async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
507 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
508 let fake_model = model.as_fake();
509
510 thread
511 .update(cx, |thread, cx| {
512 thread.send(UserMessageId::new(), ["abc"], cx)
513 })
514 .unwrap();
515 cx.run_until_parked();
516 let mut pending_completions = fake_model.pending_completions();
517 assert_eq!(
518 pending_completions.len(),
519 1,
520 "unexpected pending completions: {:?}",
521 pending_completions
522 );
523
524 let pending_completion = pending_completions.pop().unwrap();
525 assert_eq!(pending_completion.messages[0].role, Role::System);
526
527 let system_message = &pending_completion.messages[0];
528 let system_prompt = system_message.content[0].to_str().unwrap();
529 assert!(
530 !system_prompt.contains("## Tool Use"),
531 "unexpected system message: {:?}",
532 system_message
533 );
534 assert!(
535 !system_prompt.contains("## Fixing Diagnostics"),
536 "unexpected system message: {:?}",
537 system_message
538 );
539}
540
/// Checks which request messages carry `cache: true` as a conversation grows.
///
/// In every request inspected here, only the final message is expected to be marked for
/// caching: first the latest user message, and after a tool call the tool result message.
/// `messages[0]` is skipped in each comparison (`messages[1..]`).
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    // Answer the first request so the turn completes before the next message is sent.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    // Answer the second request as well.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::NAME.into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The request inspected below is the one issued after the tool call was streamed;
    // it is expected to end with the echo tool's result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::NAME.into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
686
687#[gpui::test]
688#[cfg_attr(not(feature = "e2e"), ignore)]
689async fn test_basic_tool_calls(cx: &mut TestAppContext) {
690 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
691
692 // Test a tool call that's likely to complete *before* streaming stops.
693 let events = thread
694 .update(cx, |thread, cx| {
695 thread.add_tool(EchoTool);
696 thread.send(
697 UserMessageId::new(),
698 ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
699 cx,
700 )
701 })
702 .unwrap()
703 .collect()
704 .await;
705 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
706
707 // Test a tool calls that's likely to complete *after* streaming stops.
708 let events = thread
709 .update(cx, |thread, cx| {
710 thread.remove_tool(&EchoTool::NAME);
711 thread.add_tool(DelayTool);
712 thread.send(
713 UserMessageId::new(),
714 [
715 "Now call the delay tool with 200ms.",
716 "When the timer goes off, then you echo the output of the tool.",
717 ],
718 cx,
719 )
720 })
721 .unwrap()
722 .collect()
723 .await;
724 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
725 thread.update(cx, |thread, _cx| {
726 assert!(
727 thread
728 .last_received_or_pending_message()
729 .unwrap()
730 .as_agent_message()
731 .unwrap()
732 .content
733 .iter()
734 .any(|content| {
735 if let AgentMessageContent::Text(text) = content {
736 text.contains("Ding")
737 } else {
738 false
739 }
740 }),
741 "{}",
742 thread.to_markdown()
743 );
744 });
745}
746
747#[gpui::test]
748#[cfg_attr(not(feature = "e2e"), ignore)]
749async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
750 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
751
752 // Test a tool call that's likely to complete *before* streaming stops.
753 let mut events = thread
754 .update(cx, |thread, cx| {
755 thread.add_tool(WordListTool);
756 thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
757 })
758 .unwrap();
759
760 let mut saw_partial_tool_use = false;
761 while let Some(event) = events.next().await {
762 if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
763 thread.update(cx, |thread, _cx| {
764 // Look for a tool use in the thread's last message
765 let message = thread.last_received_or_pending_message().unwrap();
766 let agent_message = message.as_agent_message().unwrap();
767 let last_content = agent_message.content.last().unwrap();
768 if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
769 assert_eq!(last_tool_use.name.as_ref(), "word_list");
770 if tool_call.status == acp::ToolCallStatus::Pending {
771 if !last_tool_use.is_input_complete
772 && last_tool_use.input.get("g").is_none()
773 {
774 saw_partial_tool_use = true;
775 }
776 } else {
777 last_tool_use
778 .input
779 .get("a")
780 .expect("'a' has streamed because input is now complete");
781 last_tool_use
782 .input
783 .get("g")
784 .expect("'g' has streamed because input is now complete");
785 }
786 } else {
787 panic!("last content should be a tool use");
788 }
789 });
790 }
791 }
792
793 assert!(
794 saw_partial_tool_use,
795 "should see at least one partially streamed tool use in the history"
796 );
797}
798
/// Walks a permission-gated tool through four calls:
///
/// 1. and 2. two simultaneous calls, the first answered with "allow" and the second
///    with "deny";
/// 3. a call answered with "always_allow:tool_requiring_permission";
/// 4. a final call whose result is asserted without answering any authorization request.
///
/// After each phase the last message of the next pending request is compared against the
/// expected tool results.
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Stream two tool uses in the same completion so both authorizations are pending at once.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
    tool_call_auth_1
        .response
        .send(acp::PermissionOptionId::new("allow"))
        .unwrap();
    cx.run_until_parked();

    // Reject the second - send "deny" option_id directly since Deny is now a button
    tool_call_auth_2
        .response
        .send(acp::PermissionOptionId::new("deny"))
        .unwrap();
    cx.run_until_parked();

    // The next request must carry both results: one success and one permission error.
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: Some("Permission to run tool denied by user".into())
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools - send transformed option_id
    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(acp::PermissionOptionId::new(
            "always_allow:tool_requiring_permission",
        ))
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission::NAME.into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    // No authorization response is sent here; the tool result must show up on its own.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission::NAME.into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}
937
938#[gpui::test]
939async fn test_tool_hallucination(cx: &mut TestAppContext) {
940 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
941 let fake_model = model.as_fake();
942
943 let mut events = thread
944 .update(cx, |thread, cx| {
945 thread.send(UserMessageId::new(), ["abc"], cx)
946 })
947 .unwrap();
948 cx.run_until_parked();
949 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
950 LanguageModelToolUse {
951 id: "tool_id_1".into(),
952 name: "nonexistent_tool".into(),
953 raw_input: "{}".into(),
954 input: json!({}),
955 is_input_complete: true,
956 thought_signature: None,
957 },
958 ));
959 fake_model.end_last_completion_stream();
960
961 let tool_call = expect_tool_call(&mut events).await;
962 assert_eq!(tool_call.title, "nonexistent_tool");
963 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
964 let update = expect_tool_call_update_fields(&mut events).await;
965 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
966}
967
968async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
969 let event = events
970 .next()
971 .await
972 .expect("no tool call authorization event received")
973 .unwrap();
974 match event {
975 ThreadEvent::ToolCall(tool_call) => tool_call,
976 event => {
977 panic!("Unexpected event {event:?}");
978 }
979 }
980}
981
982async fn expect_tool_call_update_fields(
983 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
984) -> acp::ToolCallUpdate {
985 let event = events
986 .next()
987 .await
988 .expect("no tool call authorization event received")
989 .unwrap();
990 match event {
991 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
992 event => {
993 panic!("Unexpected event {event:?}");
994 }
995 }
996}
997
998async fn next_tool_call_authorization(
999 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1000) -> ToolCallAuthorization {
1001 loop {
1002 let event = events
1003 .next()
1004 .await
1005 .expect("no tool call authorization event received")
1006 .unwrap();
1007 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1008 let permission_kinds = tool_call_authorization
1009 .options
1010 .first_option_of_kind(acp::PermissionOptionKind::AllowAlways)
1011 .map(|option| option.kind);
1012 let allow_once = tool_call_authorization
1013 .options
1014 .first_option_of_kind(acp::PermissionOptionKind::AllowOnce)
1015 .map(|option| option.kind);
1016
1017 assert_eq!(
1018 permission_kinds,
1019 Some(acp::PermissionOptionKind::AllowAlways)
1020 );
1021 assert_eq!(allow_once, Some(acp::PermissionOptionKind::AllowOnce));
1022 return tool_call_authorization;
1023 }
1024 }
1025}
1026
/// A two-word terminal command yields three choices: always for the tool, always for the
/// `cargo build` command pattern, and only this time.
#[test]
fn test_permission_options_terminal_with_pattern() {
    let context = ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    );
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `cargo build` commands"));
    assert!(labels.contains(&"Only this time"));
}
1048
/// When the second token is a flag (`ls -la`), the command pattern covers only the
/// first token (`ls`).
#[test]
fn test_permission_options_terminal_command_with_flag_second_token() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["ls -la".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `ls` commands"));
    assert!(labels.contains(&"Only this time"));
}
1068
/// A single-word terminal command (`whoami`) is used as the command pattern as-is.
#[test]
fn test_permission_options_terminal_single_word_command() {
    let context = ToolPermissionContext::new(TerminalTool::NAME, vec!["whoami".to_string()]);
    let choices = match context.build_permission_options() {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    assert_eq!(choices.len(), 3);
    let mut labels: Vec<&str> = Vec::new();
    for choice in choices.iter() {
        labels.push(choice.allow.name.as_ref());
    }
    assert!(labels.contains(&"Always for terminal"));
    assert!(labels.contains(&"Always for `whoami` commands"));
    assert!(labels.contains(&"Only this time"));
}
1088
/// Editing `src/main.rs` must offer both the tool-wide choice and a choice
/// scoped to the `src/` directory.
#[test]
fn test_permission_options_edit_file_with_path_pattern() {
    let choices =
        match ToolPermissionContext::new(EditFileTool::NAME, vec!["src/main.rs".to_string()])
            .build_permission_options()
        {
            PermissionOptions::Dropdown(choices) => choices,
            _ => panic!("Expected dropdown permission options"),
        };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for edit file", "Always for `src/`"] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1106
/// Fetching `https://docs.rs/gpui` must offer both the tool-wide choice and a
/// choice scoped to the `docs.rs` domain.
#[test]
fn test_permission_options_fetch_with_domain_pattern() {
    let choices =
        match ToolPermissionContext::new(FetchTool::NAME, vec!["https://docs.rs/gpui".to_string()])
            .build_permission_options()
        {
            PermissionOptions::Dropdown(choices) => choices,
            _ => panic!("Expected dropdown permission options"),
        };

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for fetch", "Always for `docs.rs`"] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
}
1124
/// For `./deploy.sh --production` only two choices are expected (tool-wide
/// and one-time); no label may mention a "commands" pattern.
#[test]
fn test_permission_options_without_pattern() {
    let choices = match ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["./deploy.sh --production".to_string()],
    )
    .build_permission_options()
    {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };
    assert_eq!(choices.len(), 2);

    let labels: Vec<&str> = choices
        .iter()
        .map(|choice| choice.allow.name.as_ref())
        .collect();
    for expected in ["Always for terminal", "Only this time"] {
        assert!(
            labels.contains(&expected),
            "missing label {expected:?} in {labels:?}"
        );
    }
    // No pattern-scoped ("... commands") choice may be offered here.
    assert!(labels.iter().all(|label| !label.contains("commands")));
}
1146
/// Symlink-target authorization must produce a flat list of exactly two
/// options: a one-time `allow` and a one-time `deny`.
#[test]
fn test_permission_options_symlink_target_are_flat_once_only() {
    let options = match ToolPermissionContext::symlink_target(
        EditFileTool::NAME,
        vec!["/outside/file.txt".into()],
    )
    .build_permission_options()
    {
        PermissionOptions::Flat(options) => options,
        _ => panic!("Expected flat permission options for symlink target authorization"),
    };
    assert_eq!(options.len(), 2);

    // True when some option carries both the given id and the given kind.
    let has_option = |id: &str, kind: acp::PermissionOptionKind| {
        options
            .iter()
            .any(|option| option.option_id.0.as_ref() == id && option.kind == kind)
    };
    assert!(has_option("allow", acp::PermissionOptionKind::AllowOnce));
    assert!(has_option("deny", acp::PermissionOptionKind::RejectOnce));
}
1167
/// Checks the option ids generated for a terminal command: each dropdown
/// choice pairs an allow id with a deny id, covering the tool-wide, one-time
/// and pattern-scoped variants.
#[test]
fn test_permission_option_ids_for_terminal() {
    let choices = match ToolPermissionContext::new(
        TerminalTool::NAME,
        vec!["cargo build --release".to_string()],
    )
    .build_permission_options()
    {
        PermissionOptions::Dropdown(choices) => choices,
        _ => panic!("Expected dropdown permission options"),
    };

    // Split every choice into its (allow id, deny id) pair in one pass.
    let (allow_ids, deny_ids): (Vec<String>, Vec<String>) = choices
        .iter()
        .map(|choice| {
            (
                choice.allow.option_id.0.to_string(),
                choice.deny.option_id.0.to_string(),
            )
        })
        .unzip();

    let has_id = |ids: &[String], wanted: &str| ids.iter().any(|id| id == wanted);
    let has_prefix = |ids: &[String], prefix: &str| ids.iter().any(|id| id.starts_with(prefix));

    assert!(has_id(&allow_ids, "always_allow:terminal"));
    assert!(has_id(&allow_ids, "allow"));
    assert!(
        has_prefix(&allow_ids, "always_allow_pattern:terminal\n"),
        "Missing allow pattern option"
    );

    assert!(has_id(&deny_ids, "always_deny:terminal"));
    assert!(has_id(&deny_ids, "deny"));
    assert!(
        has_prefix(&deny_ids, "always_deny_pattern:terminal\n"),
        "Missing deny pattern option"
    );
}
1207
1208#[gpui::test]
1209#[cfg_attr(not(feature = "e2e"), ignore)]
1210async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1211 let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1212
1213 // Test concurrent tool calls with different delay times
1214 let events = thread
1215 .update(cx, |thread, cx| {
1216 thread.add_tool(DelayTool);
1217 thread.send(
1218 UserMessageId::new(),
1219 [
1220 "Call the delay tool twice in the same message.",
1221 "Once with 100ms. Once with 300ms.",
1222 "When both timers are complete, describe the outputs.",
1223 ],
1224 cx,
1225 )
1226 })
1227 .unwrap()
1228 .collect()
1229 .await;
1230
1231 let stop_reasons = stop_events(events);
1232 assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1233
1234 thread.update(cx, |thread, _cx| {
1235 let last_message = thread.last_received_or_pending_message().unwrap();
1236 let agent_message = last_message.as_agent_message().unwrap();
1237 let text = agent_message
1238 .content
1239 .iter()
1240 .filter_map(|content| {
1241 if let AgentMessageContent::Text(text) = content {
1242 Some(text.as_str())
1243 } else {
1244 None
1245 }
1246 })
1247 .collect::<String>();
1248
1249 assert!(text.contains("Ding"));
1250 });
1251}
1252
1253#[gpui::test]
1254async fn test_profiles(cx: &mut TestAppContext) {
1255 let ThreadTest {
1256 model, thread, fs, ..
1257 } = setup(cx, TestModel::Fake).await;
1258 let fake_model = model.as_fake();
1259
1260 thread.update(cx, |thread, _cx| {
1261 thread.add_tool(DelayTool);
1262 thread.add_tool(EchoTool);
1263 thread.add_tool(InfiniteTool);
1264 });
1265
1266 // Override profiles and wait for settings to be loaded.
1267 fs.insert_file(
1268 paths::settings_file(),
1269 json!({
1270 "agent": {
1271 "profiles": {
1272 "test-1": {
1273 "name": "Test Profile 1",
1274 "tools": {
1275 EchoTool::NAME: true,
1276 DelayTool::NAME: true,
1277 }
1278 },
1279 "test-2": {
1280 "name": "Test Profile 2",
1281 "tools": {
1282 InfiniteTool::NAME: true,
1283 }
1284 }
1285 }
1286 }
1287 })
1288 .to_string()
1289 .into_bytes(),
1290 )
1291 .await;
1292 cx.run_until_parked();
1293
1294 // Test that test-1 profile (default) has echo and delay tools
1295 thread
1296 .update(cx, |thread, cx| {
1297 thread.set_profile(AgentProfileId("test-1".into()), cx);
1298 thread.send(UserMessageId::new(), ["test"], cx)
1299 })
1300 .unwrap();
1301 cx.run_until_parked();
1302
1303 let mut pending_completions = fake_model.pending_completions();
1304 assert_eq!(pending_completions.len(), 1);
1305 let completion = pending_completions.pop().unwrap();
1306 let tool_names: Vec<String> = completion
1307 .tools
1308 .iter()
1309 .map(|tool| tool.name.clone())
1310 .collect();
1311 assert_eq!(tool_names, vec![DelayTool::NAME, EchoTool::NAME]);
1312 fake_model.end_last_completion_stream();
1313
1314 // Switch to test-2 profile, and verify that it has only the infinite tool.
1315 thread
1316 .update(cx, |thread, cx| {
1317 thread.set_profile(AgentProfileId("test-2".into()), cx);
1318 thread.send(UserMessageId::new(), ["test2"], cx)
1319 })
1320 .unwrap();
1321 cx.run_until_parked();
1322 let mut pending_completions = fake_model.pending_completions();
1323 assert_eq!(pending_completions.len(), 1);
1324 let completion = pending_completions.pop().unwrap();
1325 let tool_names: Vec<String> = completion
1326 .tools
1327 .iter()
1328 .map(|tool| tool.name.clone())
1329 .collect();
1330 assert_eq!(tool_names, vec![InfiniteTool::NAME]);
1331}
1332
1333#[gpui::test]
1334async fn test_mcp_tools(cx: &mut TestAppContext) {
1335 let ThreadTest {
1336 model,
1337 thread,
1338 context_server_store,
1339 fs,
1340 ..
1341 } = setup(cx, TestModel::Fake).await;
1342 let fake_model = model.as_fake();
1343
1344 // Override profiles and wait for settings to be loaded.
1345 fs.insert_file(
1346 paths::settings_file(),
1347 json!({
1348 "agent": {
1349 "tool_permissions": { "default": "allow" },
1350 "profiles": {
1351 "test": {
1352 "name": "Test Profile",
1353 "enable_all_context_servers": true,
1354 "tools": {
1355 EchoTool::NAME: true,
1356 }
1357 },
1358 }
1359 }
1360 })
1361 .to_string()
1362 .into_bytes(),
1363 )
1364 .await;
1365 cx.run_until_parked();
1366 thread.update(cx, |thread, cx| {
1367 thread.set_profile(AgentProfileId("test".into()), cx)
1368 });
1369
1370 let mut mcp_tool_calls = setup_context_server(
1371 "test_server",
1372 vec![context_server::types::Tool {
1373 name: "echo".into(),
1374 description: None,
1375 input_schema: serde_json::to_value(EchoTool::input_schema(
1376 LanguageModelToolSchemaFormat::JsonSchema,
1377 ))
1378 .unwrap(),
1379 output_schema: None,
1380 annotations: None,
1381 }],
1382 &context_server_store,
1383 cx,
1384 );
1385
1386 let events = thread.update(cx, |thread, cx| {
1387 thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1388 });
1389 cx.run_until_parked();
1390
1391 // Simulate the model calling the MCP tool.
1392 let completion = fake_model.pending_completions().pop().unwrap();
1393 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1394 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1395 LanguageModelToolUse {
1396 id: "tool_1".into(),
1397 name: "echo".into(),
1398 raw_input: json!({"text": "test"}).to_string(),
1399 input: json!({"text": "test"}),
1400 is_input_complete: true,
1401 thought_signature: None,
1402 },
1403 ));
1404 fake_model.end_last_completion_stream();
1405 cx.run_until_parked();
1406
1407 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1408 assert_eq!(tool_call_params.name, "echo");
1409 assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1410 tool_call_response
1411 .send(context_server::types::CallToolResponse {
1412 content: vec![context_server::types::ToolResponseContent::Text {
1413 text: "test".into(),
1414 }],
1415 is_error: None,
1416 meta: None,
1417 structured_content: None,
1418 })
1419 .unwrap();
1420 cx.run_until_parked();
1421
1422 assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1423 fake_model.send_last_completion_stream_text_chunk("Done!");
1424 fake_model.end_last_completion_stream();
1425 events.collect::<Vec<_>>().await;
1426
1427 // Send again after adding the echo tool, ensuring the name collision is resolved.
1428 let events = thread.update(cx, |thread, cx| {
1429 thread.add_tool(EchoTool);
1430 thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1431 });
1432 cx.run_until_parked();
1433 let completion = fake_model.pending_completions().pop().unwrap();
1434 assert_eq!(
1435 tool_names_for_completion(&completion),
1436 vec!["echo", "test_server_echo"]
1437 );
1438 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1439 LanguageModelToolUse {
1440 id: "tool_2".into(),
1441 name: "test_server_echo".into(),
1442 raw_input: json!({"text": "mcp"}).to_string(),
1443 input: json!({"text": "mcp"}),
1444 is_input_complete: true,
1445 thought_signature: None,
1446 },
1447 ));
1448 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1449 LanguageModelToolUse {
1450 id: "tool_3".into(),
1451 name: "echo".into(),
1452 raw_input: json!({"text": "native"}).to_string(),
1453 input: json!({"text": "native"}),
1454 is_input_complete: true,
1455 thought_signature: None,
1456 },
1457 ));
1458 fake_model.end_last_completion_stream();
1459 cx.run_until_parked();
1460
1461 let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1462 assert_eq!(tool_call_params.name, "echo");
1463 assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1464 tool_call_response
1465 .send(context_server::types::CallToolResponse {
1466 content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1467 is_error: None,
1468 meta: None,
1469 structured_content: None,
1470 })
1471 .unwrap();
1472 cx.run_until_parked();
1473
1474 // Ensure the tool results were inserted with the correct names.
1475 let completion = fake_model.pending_completions().pop().unwrap();
1476 assert_eq!(
1477 completion.messages.last().unwrap().content,
1478 vec![
1479 MessageContent::ToolResult(LanguageModelToolResult {
1480 tool_use_id: "tool_3".into(),
1481 tool_name: "echo".into(),
1482 is_error: false,
1483 content: "native".into(),
1484 output: Some("native".into()),
1485 },),
1486 MessageContent::ToolResult(LanguageModelToolResult {
1487 tool_use_id: "tool_2".into(),
1488 tool_name: "test_server_echo".into(),
1489 is_error: false,
1490 content: "mcp".into(),
1491 output: Some("mcp".into()),
1492 },),
1493 ]
1494 );
1495 fake_model.end_last_completion_stream();
1496 events.collect::<Vec<_>>().await;
1497}
1498
/// Verifies that the result of an MCP tool call stays available when the
/// thread is replayed after the providing context server has been stopped.
///
/// Flow: register a `github_server` context server exposing `issue_read`, let
/// the fake model call it, answer the call from the server side, finish the
/// turn, stop the server, replay the thread, and check that replay still emits
/// the tool call plus an update carrying the saved `raw_output` and a
/// `Completed` status.
#[gpui::test]
async fn test_mcp_tool_result_displayed_when_server_disconnected(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Setup settings to allow MCP tools
    // NOTE(review): `test_mcp_tools` in this file grants permission via
    // `"tool_permissions": { "default": "allow" }` instead of this key —
    // confirm `always_allow_tool_actions` is still honored by the settings schema.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {}
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Setup a context server with a tool
    let mut mcp_tool_calls = setup_context_server(
        "github_server",
        vec![context_server::types::Tool {
            name: "issue_read".into(),
            description: Some("Read a GitHub issue".into()),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "issue_url": { "type": "string" }
                }
            }),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    // Send a message and have the model call the MCP tool
    let events = thread.update(cx, |thread, cx| {
        thread
            .send(UserMessageId::new(), ["Read issue #47404"], cx)
            .unwrap()
    });
    cx.run_until_parked();

    // Verify the MCP tool is available to the model
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["issue_read"],
        "MCP tool should be available"
    );

    // Simulate the model calling the MCP tool
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "issue_read".into(),
            raw_input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"})
                .to_string(),
            input: json!({"issue_url": "https://github.com/zed-industries/zed/issues/47404"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server receives the tool call and responds with content
    let expected_tool_output = "Issue #47404: Tool call results are cleared upon app close";
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "issue_read");
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: expected_tool_output.into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // After tool completes, the model continues with a new completion request
    // that includes the tool results. We need to respond to this.
    let _completion = fake_model.pending_completions().pop().unwrap();
    fake_model.send_last_completion_stream_text_chunk("I found the issue!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    // Drain the event stream so the turn has fully finished before inspecting the thread.
    events.collect::<Vec<_>>().await;

    // Verify the tool result is stored in the thread by checking the markdown output.
    // The tool result is in the first assistant message (not the last one, which is
    // the model's response after the tool completed).
    thread.update(cx, |thread, _cx| {
        let markdown = thread.to_markdown();
        assert!(
            markdown.contains("**Tool Result**: issue_read"),
            "Thread should contain tool result header"
        );
        assert!(
            markdown.contains(expected_tool_output),
            "Thread should contain tool output: {}",
            expected_tool_output
        );
    });

    // Simulate app restart: disconnect the MCP server.
    // After restart, the MCP server won't be connected yet when the thread is replayed.
    context_server_store.update(cx, |store, cx| {
        // The result of stopping the server is deliberately ignored.
        let _ = store.stop_server(&ContextServerId("github_server".into()), cx);
    });
    cx.run_until_parked();

    // Replay the thread (this is what happens when loading a saved thread)
    let mut replay_events = thread.update(cx, |thread, cx| thread.replay(cx));

    let mut found_tool_call = None;
    let mut found_tool_call_update_with_output = None;

    // Scan the replayed events for the original tool call (`tool_1`) and for
    // an update to it that carries a raw output.
    while let Some(event) = replay_events.next().await {
        let event = event.unwrap();
        match &event {
            ThreadEvent::ToolCall(tc) if tc.tool_call_id.to_string() == "tool_1" => {
                found_tool_call = Some(tc.clone());
            }
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update))
                if update.tool_call_id.to_string() == "tool_1" =>
            {
                if update.fields.raw_output.is_some() {
                    found_tool_call_update_with_output = Some(update.clone());
                }
            }
            _ => {}
        }
    }

    // The tool call should be found
    assert!(
        found_tool_call.is_some(),
        "Tool call should be emitted during replay"
    );

    assert!(
        found_tool_call_update_with_output.is_some(),
        "ToolCallUpdate with raw_output should be emitted even when MCP server is disconnected."
    );

    let update = found_tool_call_update_with_output.unwrap();
    assert_eq!(
        update.fields.raw_output,
        Some(expected_tool_output.into()),
        "raw_output should contain the saved tool result"
    );

    // Also verify the status is correct (completed, not failed)
    assert_eq!(
        update.fields.status,
        Some(acp::ToolCallStatus::Completed),
        "Tool call status should reflect the original completion status"
    );
}
1680
/// Checks the tool names offered to the model when native tools and several
/// context servers expose colliding and very long tool names.
///
/// As the expected list at the end shows: the native `echo` keeps its name
/// while the servers' `echo` tools appear as `xxx_echo`, `yyy_echo` and
/// `azure_dev_ops_echo`; the two long `a…` tools appear with `y_` / `z_`
/// prefixes; and only a single `b…` entry and a single `c…` entry are present.
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::NAME: true,
                            DelayTool::NAME: true,
                            WordListTool::NAME: true,
                            ToolRequiringPermission::NAME: true,
                            InfiniteTool::NAME: true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Activate the profile and register the native tools it enables.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Second server: another `echo`, a unique tool, and two tool names just
    // below MAX_TOOL_NAME_LENGTH (2 and 1 characters short, respectively).
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Third server: repeats the same long `a…` and `b…` names and adds a
    // `c…` name that is one character longer than MAX_TOOL_NAME_LENGTH.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server with spaces in name - tests snake_case conversion for API compatibility
    let _server4_calls = setup_context_server(
        "Azure DevOps",
        vec![context_server::types::Tool {
            name: "echo".into(), // Also conflicts - will be disambiguated as azure_dev_ops_echo
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Inspect the tool list of the request that the send above produced.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "azure_dev_ops_echo",
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1864
/// End-to-end cancellation check (ignored unless the `e2e` feature is enabled).
///
/// Waits until the echo tool has completed and the infinite tool has been
/// called, cancels the thread, asserts that the event stream ends with a
/// `Cancelled` stop, and finally confirms that a fresh message can still be
/// sent and answered.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Tool calls must arrive in this order; each one is removed from the
    // front of the list as it is observed.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // A `Completed` status update for the echo tool call recorded above.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_received_or_pending_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1950
1951#[gpui::test]
1952async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1953 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1954 always_allow_tools(cx);
1955 let fake_model = model.as_fake();
1956
1957 let environment = Rc::new(cx.update(|cx| {
1958 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
1959 }));
1960 let handle = environment.terminal_handle.clone().unwrap();
1961
1962 let mut events = thread
1963 .update(cx, |thread, cx| {
1964 thread.add_tool(crate::TerminalTool::new(
1965 thread.project().clone(),
1966 environment,
1967 ));
1968 thread.send(UserMessageId::new(), ["run a command"], cx)
1969 })
1970 .unwrap();
1971
1972 cx.run_until_parked();
1973
1974 // Simulate the model calling the terminal tool
1975 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1976 LanguageModelToolUse {
1977 id: "terminal_tool_1".into(),
1978 name: TerminalTool::NAME.into(),
1979 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1980 input: json!({"command": "sleep 1000", "cd": "."}),
1981 is_input_complete: true,
1982 thought_signature: None,
1983 },
1984 ));
1985 fake_model.end_last_completion_stream();
1986
1987 // Wait for the terminal tool to start running
1988 wait_for_terminal_tool_started(&mut events, cx).await;
1989
1990 // Cancel the thread while the terminal is running
1991 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1992
1993 // Collect remaining events, driving the executor to let cancellation complete
1994 let remaining_events = collect_events_until_stop(&mut events, cx).await;
1995
1996 // Verify the terminal was killed
1997 assert!(
1998 handle.was_killed(),
1999 "expected terminal handle to be killed on cancellation"
2000 );
2001
2002 // Verify we got a cancellation stop event
2003 assert_eq!(
2004 stop_events(remaining_events),
2005 vec![acp::StopReason::Cancelled],
2006 );
2007
2008 // Verify the tool result contains the terminal output, not just "Tool canceled by user"
2009 thread.update(cx, |thread, _cx| {
2010 let message = thread.last_received_or_pending_message().unwrap();
2011 let agent_message = message.as_agent_message().unwrap();
2012
2013 let tool_use = agent_message
2014 .content
2015 .iter()
2016 .find_map(|content| match content {
2017 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2018 _ => None,
2019 })
2020 .expect("expected tool use in agent message");
2021
2022 let tool_result = agent_message
2023 .tool_results
2024 .get(&tool_use.id)
2025 .expect("expected tool result");
2026
2027 let result_text = match &tool_result.content {
2028 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2029 _ => panic!("expected text content in tool result"),
2030 };
2031
2032 // "partial output" comes from FakeTerminalHandle's output field
2033 assert!(
2034 result_text.contains("partial output"),
2035 "expected tool result to contain terminal output, got: {result_text}"
2036 );
2037 // Match the actual format from process_content in terminal_tool.rs
2038 assert!(
2039 result_text.contains("The user stopped this command"),
2040 "expected tool result to indicate user stopped, got: {result_text}"
2041 );
2042 });
2043
2044 // Verify we can send a new message after cancellation
2045 verify_thread_recovery(&thread, &fake_model, cx).await;
2046}
2047
2048#[gpui::test]
2049async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
2050 // This test verifies that tools which properly handle cancellation via
2051 // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
2052 // to cancellation and report that they were cancelled.
2053 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2054 always_allow_tools(cx);
2055 let fake_model = model.as_fake();
2056
2057 let (tool, was_cancelled) = CancellationAwareTool::new();
2058
2059 let mut events = thread
2060 .update(cx, |thread, cx| {
2061 thread.add_tool(tool);
2062 thread.send(
2063 UserMessageId::new(),
2064 ["call the cancellation aware tool"],
2065 cx,
2066 )
2067 })
2068 .unwrap();
2069
2070 cx.run_until_parked();
2071
2072 // Simulate the model calling the cancellation-aware tool
2073 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2074 LanguageModelToolUse {
2075 id: "cancellation_aware_1".into(),
2076 name: "cancellation_aware".into(),
2077 raw_input: r#"{}"#.into(),
2078 input: json!({}),
2079 is_input_complete: true,
2080 thought_signature: None,
2081 },
2082 ));
2083 fake_model.end_last_completion_stream();
2084
2085 cx.run_until_parked();
2086
2087 // Wait for the tool call to be reported
2088 let mut tool_started = false;
2089 let deadline = cx.executor().num_cpus() * 100;
2090 for _ in 0..deadline {
2091 cx.run_until_parked();
2092
2093 while let Some(Some(event)) = events.next().now_or_never() {
2094 if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
2095 if tool_call.title == "Cancellation Aware Tool" {
2096 tool_started = true;
2097 break;
2098 }
2099 }
2100 }
2101
2102 if tool_started {
2103 break;
2104 }
2105
2106 cx.background_executor
2107 .timer(Duration::from_millis(10))
2108 .await;
2109 }
2110 assert!(tool_started, "expected cancellation aware tool to start");
2111
2112 // Cancel the thread and wait for it to complete
2113 let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
2114
2115 // The cancel task should complete promptly because the tool handles cancellation
2116 let timeout = cx.background_executor.timer(Duration::from_secs(5));
2117 futures::select! {
2118 _ = cancel_task.fuse() => {}
2119 _ = timeout.fuse() => {
2120 panic!("cancel task timed out - tool did not respond to cancellation");
2121 }
2122 }
2123
2124 // Verify the tool detected cancellation via its flag
2125 assert!(
2126 was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
2127 "tool should have detected cancellation via event_stream.cancelled_by_user()"
2128 );
2129
2130 // Collect remaining events
2131 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2132
2133 // Verify we got a cancellation stop event
2134 assert_eq!(
2135 stop_events(remaining_events),
2136 vec![acp::StopReason::Cancelled],
2137 );
2138
2139 // Verify we can send a new message after cancellation
2140 verify_thread_recovery(&thread, &fake_model, cx).await;
2141}
2142
2143/// Helper to verify thread can recover after cancellation by sending a simple message.
2144async fn verify_thread_recovery(
2145 thread: &Entity<Thread>,
2146 fake_model: &FakeLanguageModel,
2147 cx: &mut TestAppContext,
2148) {
2149 let events = thread
2150 .update(cx, |thread, cx| {
2151 thread.send(
2152 UserMessageId::new(),
2153 ["Testing: reply with 'Hello' then stop."],
2154 cx,
2155 )
2156 })
2157 .unwrap();
2158 cx.run_until_parked();
2159 fake_model.send_last_completion_stream_text_chunk("Hello");
2160 fake_model
2161 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2162 fake_model.end_last_completion_stream();
2163
2164 let events = events.collect::<Vec<_>>().await;
2165 thread.update(cx, |thread, _cx| {
2166 let message = thread.last_received_or_pending_message().unwrap();
2167 let agent_message = message.as_agent_message().unwrap();
2168 assert_eq!(
2169 agent_message.content,
2170 vec![AgentMessageContent::Text("Hello".to_string())]
2171 );
2172 });
2173 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
2174}
2175
2176/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
2177async fn wait_for_terminal_tool_started(
2178 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2179 cx: &mut TestAppContext,
2180) {
2181 let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
2182 for _ in 0..deadline {
2183 cx.run_until_parked();
2184
2185 while let Some(Some(event)) = events.next().now_or_never() {
2186 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2187 update,
2188 ))) = &event
2189 {
2190 if update.fields.content.as_ref().is_some_and(|content| {
2191 content
2192 .iter()
2193 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2194 }) {
2195 return;
2196 }
2197 }
2198 }
2199
2200 cx.background_executor
2201 .timer(Duration::from_millis(10))
2202 .await;
2203 }
2204 panic!("terminal tool did not start within the expected time");
2205}
2206
2207/// Collects events until a Stop event is received, driving the executor to completion.
2208async fn collect_events_until_stop(
2209 events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
2210 cx: &mut TestAppContext,
2211) -> Vec<Result<ThreadEvent>> {
2212 let mut collected = Vec::new();
2213 let deadline = cx.executor().num_cpus() * 200;
2214
2215 for _ in 0..deadline {
2216 cx.executor().advance_clock(Duration::from_millis(10));
2217 cx.run_until_parked();
2218
2219 while let Some(Some(event)) = events.next().now_or_never() {
2220 let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
2221 collected.push(event);
2222 if is_stop {
2223 return collected;
2224 }
2225 }
2226 }
2227 panic!(
2228 "did not receive Stop event within the expected time; collected {} events",
2229 collected.len()
2230 );
2231}
2232
2233#[gpui::test]
2234async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
2235 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2236 always_allow_tools(cx);
2237 let fake_model = model.as_fake();
2238
2239 let environment = Rc::new(cx.update(|cx| {
2240 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2241 }));
2242 let handle = environment.terminal_handle.clone().unwrap();
2243
2244 let message_id = UserMessageId::new();
2245 let mut events = thread
2246 .update(cx, |thread, cx| {
2247 thread.add_tool(crate::TerminalTool::new(
2248 thread.project().clone(),
2249 environment,
2250 ));
2251 thread.send(message_id.clone(), ["run a command"], cx)
2252 })
2253 .unwrap();
2254
2255 cx.run_until_parked();
2256
2257 // Simulate the model calling the terminal tool
2258 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2259 LanguageModelToolUse {
2260 id: "terminal_tool_1".into(),
2261 name: TerminalTool::NAME.into(),
2262 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2263 input: json!({"command": "sleep 1000", "cd": "."}),
2264 is_input_complete: true,
2265 thought_signature: None,
2266 },
2267 ));
2268 fake_model.end_last_completion_stream();
2269
2270 // Wait for the terminal tool to start running
2271 wait_for_terminal_tool_started(&mut events, cx).await;
2272
2273 // Truncate the thread while the terminal is running
2274 thread
2275 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2276 .unwrap();
2277
2278 // Drive the executor to let cancellation complete
2279 let _ = collect_events_until_stop(&mut events, cx).await;
2280
2281 // Verify the terminal was killed
2282 assert!(
2283 handle.was_killed(),
2284 "expected terminal handle to be killed on truncate"
2285 );
2286
2287 // Verify the thread is empty after truncation
2288 thread.update(cx, |thread, _cx| {
2289 assert_eq!(
2290 thread.to_markdown(),
2291 "",
2292 "expected thread to be empty after truncating the only message"
2293 );
2294 });
2295
2296 // Verify we can send a new message after truncation
2297 verify_thread_recovery(&thread, &fake_model, cx).await;
2298}
2299
2300#[gpui::test]
2301async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2302 // Tests that cancellation properly kills all running terminal tools when multiple are active.
2303 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2304 always_allow_tools(cx);
2305 let fake_model = model.as_fake();
2306
2307 let environment = Rc::new(MultiTerminalEnvironment::new());
2308
2309 let mut events = thread
2310 .update(cx, |thread, cx| {
2311 thread.add_tool(crate::TerminalTool::new(
2312 thread.project().clone(),
2313 environment.clone(),
2314 ));
2315 thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2316 })
2317 .unwrap();
2318
2319 cx.run_until_parked();
2320
2321 // Simulate the model calling two terminal tools
2322 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2323 LanguageModelToolUse {
2324 id: "terminal_tool_1".into(),
2325 name: TerminalTool::NAME.into(),
2326 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2327 input: json!({"command": "sleep 1000", "cd": "."}),
2328 is_input_complete: true,
2329 thought_signature: None,
2330 },
2331 ));
2332 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2333 LanguageModelToolUse {
2334 id: "terminal_tool_2".into(),
2335 name: TerminalTool::NAME.into(),
2336 raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2337 input: json!({"command": "sleep 2000", "cd": "."}),
2338 is_input_complete: true,
2339 thought_signature: None,
2340 },
2341 ));
2342 fake_model.end_last_completion_stream();
2343
2344 // Wait for both terminal tools to start by counting terminal content updates
2345 let mut terminals_started = 0;
2346 let deadline = cx.executor().num_cpus() * 100;
2347 for _ in 0..deadline {
2348 cx.run_until_parked();
2349
2350 while let Some(Some(event)) = events.next().now_or_never() {
2351 if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2352 update,
2353 ))) = &event
2354 {
2355 if update.fields.content.as_ref().is_some_and(|content| {
2356 content
2357 .iter()
2358 .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2359 }) {
2360 terminals_started += 1;
2361 if terminals_started >= 2 {
2362 break;
2363 }
2364 }
2365 }
2366 }
2367 if terminals_started >= 2 {
2368 break;
2369 }
2370
2371 cx.background_executor
2372 .timer(Duration::from_millis(10))
2373 .await;
2374 }
2375 assert!(
2376 terminals_started >= 2,
2377 "expected 2 terminal tools to start, got {terminals_started}"
2378 );
2379
2380 // Cancel the thread while both terminals are running
2381 thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2382
2383 // Collect remaining events
2384 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2385
2386 // Verify both terminal handles were killed
2387 let handles = environment.handles();
2388 assert_eq!(
2389 handles.len(),
2390 2,
2391 "expected 2 terminal handles to be created"
2392 );
2393 assert!(
2394 handles[0].was_killed(),
2395 "expected first terminal handle to be killed on cancellation"
2396 );
2397 assert!(
2398 handles[1].was_killed(),
2399 "expected second terminal handle to be killed on cancellation"
2400 );
2401
2402 // Verify we got a cancellation stop event
2403 assert_eq!(
2404 stop_events(remaining_events),
2405 vec![acp::StopReason::Cancelled],
2406 );
2407}
2408
2409#[gpui::test]
2410async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2411 // Tests that clicking the stop button on the terminal card (as opposed to the main
2412 // cancel button) properly reports user stopped via the was_stopped_by_user path.
2413 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2414 always_allow_tools(cx);
2415 let fake_model = model.as_fake();
2416
2417 let environment = Rc::new(cx.update(|cx| {
2418 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2419 }));
2420 let handle = environment.terminal_handle.clone().unwrap();
2421
2422 let mut events = thread
2423 .update(cx, |thread, cx| {
2424 thread.add_tool(crate::TerminalTool::new(
2425 thread.project().clone(),
2426 environment,
2427 ));
2428 thread.send(UserMessageId::new(), ["run a command"], cx)
2429 })
2430 .unwrap();
2431
2432 cx.run_until_parked();
2433
2434 // Simulate the model calling the terminal tool
2435 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2436 LanguageModelToolUse {
2437 id: "terminal_tool_1".into(),
2438 name: TerminalTool::NAME.into(),
2439 raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2440 input: json!({"command": "sleep 1000", "cd": "."}),
2441 is_input_complete: true,
2442 thought_signature: None,
2443 },
2444 ));
2445 fake_model.end_last_completion_stream();
2446
2447 // Wait for the terminal tool to start running
2448 wait_for_terminal_tool_started(&mut events, cx).await;
2449
2450 // Simulate user clicking stop on the terminal card itself.
2451 // This sets the flag and signals exit (simulating what the real UI would do).
2452 handle.set_stopped_by_user(true);
2453 handle.killed.store(true, Ordering::SeqCst);
2454 handle.signal_exit();
2455
2456 // Wait for the tool to complete
2457 cx.run_until_parked();
2458
2459 // The thread continues after tool completion - simulate the model ending its turn
2460 fake_model
2461 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2462 fake_model.end_last_completion_stream();
2463
2464 // Collect remaining events
2465 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2466
2467 // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2468 assert_eq!(
2469 stop_events(remaining_events),
2470 vec![acp::StopReason::EndTurn],
2471 );
2472
2473 // Verify the tool result indicates user stopped
2474 thread.update(cx, |thread, _cx| {
2475 let message = thread.last_received_or_pending_message().unwrap();
2476 let agent_message = message.as_agent_message().unwrap();
2477
2478 let tool_use = agent_message
2479 .content
2480 .iter()
2481 .find_map(|content| match content {
2482 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2483 _ => None,
2484 })
2485 .expect("expected tool use in agent message");
2486
2487 let tool_result = agent_message
2488 .tool_results
2489 .get(&tool_use.id)
2490 .expect("expected tool result");
2491
2492 let result_text = match &tool_result.content {
2493 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2494 _ => panic!("expected text content in tool result"),
2495 };
2496
2497 assert!(
2498 result_text.contains("The user stopped this command"),
2499 "expected tool result to indicate user stopped, got: {result_text}"
2500 );
2501 });
2502}
2503
2504#[gpui::test]
2505async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2506 // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2507 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2508 always_allow_tools(cx);
2509 let fake_model = model.as_fake();
2510
2511 let environment = Rc::new(cx.update(|cx| {
2512 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
2513 }));
2514 let handle = environment.terminal_handle.clone().unwrap();
2515
2516 let mut events = thread
2517 .update(cx, |thread, cx| {
2518 thread.add_tool(crate::TerminalTool::new(
2519 thread.project().clone(),
2520 environment,
2521 ));
2522 thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2523 })
2524 .unwrap();
2525
2526 cx.run_until_parked();
2527
2528 // Simulate the model calling the terminal tool with a short timeout
2529 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2530 LanguageModelToolUse {
2531 id: "terminal_tool_1".into(),
2532 name: TerminalTool::NAME.into(),
2533 raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2534 input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2535 is_input_complete: true,
2536 thought_signature: None,
2537 },
2538 ));
2539 fake_model.end_last_completion_stream();
2540
2541 // Wait for the terminal tool to start running
2542 wait_for_terminal_tool_started(&mut events, cx).await;
2543
2544 // Advance clock past the timeout
2545 cx.executor().advance_clock(Duration::from_millis(200));
2546 cx.run_until_parked();
2547
2548 // The thread continues after tool completion - simulate the model ending its turn
2549 fake_model
2550 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2551 fake_model.end_last_completion_stream();
2552
2553 // Collect remaining events
2554 let remaining_events = collect_events_until_stop(&mut events, cx).await;
2555
2556 // Verify the terminal was killed due to timeout
2557 assert!(
2558 handle.was_killed(),
2559 "expected terminal handle to be killed on timeout"
2560 );
2561
2562 // Verify we got an EndTurn (the tool completed, just with timeout)
2563 assert_eq!(
2564 stop_events(remaining_events),
2565 vec![acp::StopReason::EndTurn],
2566 );
2567
2568 // Verify the tool result indicates timeout, not user stopped
2569 thread.update(cx, |thread, _cx| {
2570 let message = thread.last_received_or_pending_message().unwrap();
2571 let agent_message = message.as_agent_message().unwrap();
2572
2573 let tool_use = agent_message
2574 .content
2575 .iter()
2576 .find_map(|content| match content {
2577 AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2578 _ => None,
2579 })
2580 .expect("expected tool use in agent message");
2581
2582 let tool_result = agent_message
2583 .tool_results
2584 .get(&tool_use.id)
2585 .expect("expected tool result");
2586
2587 let result_text = match &tool_result.content {
2588 language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2589 _ => panic!("expected text content in tool result"),
2590 };
2591
2592 assert!(
2593 result_text.contains("timed out"),
2594 "expected tool result to indicate timeout, got: {result_text}"
2595 );
2596 assert!(
2597 !result_text.contains("The user stopped"),
2598 "tool result should not mention user stopped when it timed out, got: {result_text}"
2599 );
2600 });
2601}
2602
2603#[gpui::test]
2604async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2605 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2606 let fake_model = model.as_fake();
2607
2608 let events_1 = thread
2609 .update(cx, |thread, cx| {
2610 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2611 })
2612 .unwrap();
2613 cx.run_until_parked();
2614 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2615 cx.run_until_parked();
2616
2617 let events_2 = thread
2618 .update(cx, |thread, cx| {
2619 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2620 })
2621 .unwrap();
2622 cx.run_until_parked();
2623 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2624 fake_model
2625 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2626 fake_model.end_last_completion_stream();
2627
2628 let events_1 = events_1.collect::<Vec<_>>().await;
2629 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2630 let events_2 = events_2.collect::<Vec<_>>().await;
2631 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2632}
2633
2634#[gpui::test]
2635async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2636 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2637 let fake_model = model.as_fake();
2638
2639 let events_1 = thread
2640 .update(cx, |thread, cx| {
2641 thread.send(UserMessageId::new(), ["Hello 1"], cx)
2642 })
2643 .unwrap();
2644 cx.run_until_parked();
2645 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2646 fake_model
2647 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2648 fake_model.end_last_completion_stream();
2649 let events_1 = events_1.collect::<Vec<_>>().await;
2650
2651 let events_2 = thread
2652 .update(cx, |thread, cx| {
2653 thread.send(UserMessageId::new(), ["Hello 2"], cx)
2654 })
2655 .unwrap();
2656 cx.run_until_parked();
2657 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2658 fake_model
2659 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2660 fake_model.end_last_completion_stream();
2661 let events_2 = events_2.collect::<Vec<_>>().await;
2662
2663 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2664 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2665}
2666
2667#[gpui::test]
2668async fn test_refusal(cx: &mut TestAppContext) {
2669 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2670 let fake_model = model.as_fake();
2671
2672 let events = thread
2673 .update(cx, |thread, cx| {
2674 thread.send(UserMessageId::new(), ["Hello"], cx)
2675 })
2676 .unwrap();
2677 cx.run_until_parked();
2678 thread.read_with(cx, |thread, _| {
2679 assert_eq!(
2680 thread.to_markdown(),
2681 indoc! {"
2682 ## User
2683
2684 Hello
2685 "}
2686 );
2687 });
2688
2689 fake_model.send_last_completion_stream_text_chunk("Hey!");
2690 cx.run_until_parked();
2691 thread.read_with(cx, |thread, _| {
2692 assert_eq!(
2693 thread.to_markdown(),
2694 indoc! {"
2695 ## User
2696
2697 Hello
2698
2699 ## Assistant
2700
2701 Hey!
2702 "}
2703 );
2704 });
2705
2706 // If the model refuses to continue, the thread should remove all the messages after the last user message.
2707 fake_model
2708 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2709 let events = events.collect::<Vec<_>>().await;
2710 assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2711 thread.read_with(cx, |thread, _| {
2712 assert_eq!(thread.to_markdown(), "");
2713 });
2714}
2715
2716#[gpui::test]
2717async fn test_truncate_first_message(cx: &mut TestAppContext) {
2718 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2719 let fake_model = model.as_fake();
2720
2721 let message_id = UserMessageId::new();
2722 thread
2723 .update(cx, |thread, cx| {
2724 thread.send(message_id.clone(), ["Hello"], cx)
2725 })
2726 .unwrap();
2727 cx.run_until_parked();
2728 thread.read_with(cx, |thread, _| {
2729 assert_eq!(
2730 thread.to_markdown(),
2731 indoc! {"
2732 ## User
2733
2734 Hello
2735 "}
2736 );
2737 assert_eq!(thread.latest_token_usage(), None);
2738 });
2739
2740 fake_model.send_last_completion_stream_text_chunk("Hey!");
2741 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2742 language_model::TokenUsage {
2743 input_tokens: 32_000,
2744 output_tokens: 16_000,
2745 cache_creation_input_tokens: 0,
2746 cache_read_input_tokens: 0,
2747 },
2748 ));
2749 cx.run_until_parked();
2750 thread.read_with(cx, |thread, _| {
2751 assert_eq!(
2752 thread.to_markdown(),
2753 indoc! {"
2754 ## User
2755
2756 Hello
2757
2758 ## Assistant
2759
2760 Hey!
2761 "}
2762 );
2763 assert_eq!(
2764 thread.latest_token_usage(),
2765 Some(acp_thread::TokenUsage {
2766 used_tokens: 32_000 + 16_000,
2767 max_tokens: 1_000_000,
2768 max_output_tokens: None,
2769 input_tokens: 32_000,
2770 output_tokens: 16_000,
2771 })
2772 );
2773 });
2774
2775 thread
2776 .update(cx, |thread, cx| thread.truncate(message_id, cx))
2777 .unwrap();
2778 cx.run_until_parked();
2779 thread.read_with(cx, |thread, _| {
2780 assert_eq!(thread.to_markdown(), "");
2781 assert_eq!(thread.latest_token_usage(), None);
2782 });
2783
2784 // Ensure we can still send a new message after truncation.
2785 thread
2786 .update(cx, |thread, cx| {
2787 thread.send(UserMessageId::new(), ["Hi"], cx)
2788 })
2789 .unwrap();
2790 thread.update(cx, |thread, _cx| {
2791 assert_eq!(
2792 thread.to_markdown(),
2793 indoc! {"
2794 ## User
2795
2796 Hi
2797 "}
2798 );
2799 });
2800 cx.run_until_parked();
2801 fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2802 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2803 language_model::TokenUsage {
2804 input_tokens: 40_000,
2805 output_tokens: 20_000,
2806 cache_creation_input_tokens: 0,
2807 cache_read_input_tokens: 0,
2808 },
2809 ));
2810 cx.run_until_parked();
2811 thread.read_with(cx, |thread, _| {
2812 assert_eq!(
2813 thread.to_markdown(),
2814 indoc! {"
2815 ## User
2816
2817 Hi
2818
2819 ## Assistant
2820
2821 Ahoy!
2822 "}
2823 );
2824
2825 assert_eq!(
2826 thread.latest_token_usage(),
2827 Some(acp_thread::TokenUsage {
2828 used_tokens: 40_000 + 20_000,
2829 max_tokens: 1_000_000,
2830 max_output_tokens: None,
2831 input_tokens: 40_000,
2832 output_tokens: 20_000,
2833 })
2834 );
2835 });
2836}
2837
2838#[gpui::test]
2839async fn test_truncate_second_message(cx: &mut TestAppContext) {
2840 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2841 let fake_model = model.as_fake();
2842
2843 thread
2844 .update(cx, |thread, cx| {
2845 thread.send(UserMessageId::new(), ["Message 1"], cx)
2846 })
2847 .unwrap();
2848 cx.run_until_parked();
2849 fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2850 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2851 language_model::TokenUsage {
2852 input_tokens: 32_000,
2853 output_tokens: 16_000,
2854 cache_creation_input_tokens: 0,
2855 cache_read_input_tokens: 0,
2856 },
2857 ));
2858 fake_model.end_last_completion_stream();
2859 cx.run_until_parked();
2860
2861 let assert_first_message_state = |cx: &mut TestAppContext| {
2862 thread.clone().read_with(cx, |thread, _| {
2863 assert_eq!(
2864 thread.to_markdown(),
2865 indoc! {"
2866 ## User
2867
2868 Message 1
2869
2870 ## Assistant
2871
2872 Message 1 response
2873 "}
2874 );
2875
2876 assert_eq!(
2877 thread.latest_token_usage(),
2878 Some(acp_thread::TokenUsage {
2879 used_tokens: 32_000 + 16_000,
2880 max_tokens: 1_000_000,
2881 max_output_tokens: None,
2882 input_tokens: 32_000,
2883 output_tokens: 16_000,
2884 })
2885 );
2886 });
2887 };
2888
2889 assert_first_message_state(cx);
2890
2891 let second_message_id = UserMessageId::new();
2892 thread
2893 .update(cx, |thread, cx| {
2894 thread.send(second_message_id.clone(), ["Message 2"], cx)
2895 })
2896 .unwrap();
2897 cx.run_until_parked();
2898
2899 fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2900 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2901 language_model::TokenUsage {
2902 input_tokens: 40_000,
2903 output_tokens: 20_000,
2904 cache_creation_input_tokens: 0,
2905 cache_read_input_tokens: 0,
2906 },
2907 ));
2908 fake_model.end_last_completion_stream();
2909 cx.run_until_parked();
2910
2911 thread.read_with(cx, |thread, _| {
2912 assert_eq!(
2913 thread.to_markdown(),
2914 indoc! {"
2915 ## User
2916
2917 Message 1
2918
2919 ## Assistant
2920
2921 Message 1 response
2922
2923 ## User
2924
2925 Message 2
2926
2927 ## Assistant
2928
2929 Message 2 response
2930 "}
2931 );
2932
2933 assert_eq!(
2934 thread.latest_token_usage(),
2935 Some(acp_thread::TokenUsage {
2936 used_tokens: 40_000 + 20_000,
2937 max_tokens: 1_000_000,
2938 max_output_tokens: None,
2939 input_tokens: 40_000,
2940 output_tokens: 20_000,
2941 })
2942 );
2943 });
2944
2945 thread
2946 .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2947 .unwrap();
2948 cx.run_until_parked();
2949
2950 assert_first_message_state(cx);
2951}
2952
2953#[gpui::test]
2954async fn test_title_generation(cx: &mut TestAppContext) {
2955 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2956 let fake_model = model.as_fake();
2957
2958 let summary_model = Arc::new(FakeLanguageModel::default());
2959 thread.update(cx, |thread, cx| {
2960 thread.set_summarization_model(Some(summary_model.clone()), cx)
2961 });
2962
2963 let send = thread
2964 .update(cx, |thread, cx| {
2965 thread.send(UserMessageId::new(), ["Hello"], cx)
2966 })
2967 .unwrap();
2968 cx.run_until_parked();
2969
2970 fake_model.send_last_completion_stream_text_chunk("Hey!");
2971 fake_model.end_last_completion_stream();
2972 cx.run_until_parked();
2973 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2974
2975 // Ensure the summary model has been invoked to generate a title.
2976 summary_model.send_last_completion_stream_text_chunk("Hello ");
2977 summary_model.send_last_completion_stream_text_chunk("world\nG");
2978 summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2979 summary_model.end_last_completion_stream();
2980 send.collect::<Vec<_>>().await;
2981 cx.run_until_parked();
2982 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2983
2984 // Send another message, ensuring no title is generated this time.
2985 let send = thread
2986 .update(cx, |thread, cx| {
2987 thread.send(UserMessageId::new(), ["Hello again"], cx)
2988 })
2989 .unwrap();
2990 cx.run_until_parked();
2991 fake_model.send_last_completion_stream_text_chunk("Hey again!");
2992 fake_model.end_last_completion_stream();
2993 cx.run_until_parked();
2994 assert_eq!(summary_model.pending_completions(), Vec::new());
2995 send.collect::<Vec<_>>().await;
2996 thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2997}
2998
2999#[gpui::test]
3000async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
3001 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3002 let fake_model = model.as_fake();
3003
3004 let _events = thread
3005 .update(cx, |thread, cx| {
3006 thread.add_tool(ToolRequiringPermission);
3007 thread.add_tool(EchoTool);
3008 thread.send(UserMessageId::new(), ["Hey!"], cx)
3009 })
3010 .unwrap();
3011 cx.run_until_parked();
3012
3013 let permission_tool_use = LanguageModelToolUse {
3014 id: "tool_id_1".into(),
3015 name: ToolRequiringPermission::NAME.into(),
3016 raw_input: "{}".into(),
3017 input: json!({}),
3018 is_input_complete: true,
3019 thought_signature: None,
3020 };
3021 let echo_tool_use = LanguageModelToolUse {
3022 id: "tool_id_2".into(),
3023 name: EchoTool::NAME.into(),
3024 raw_input: json!({"text": "test"}).to_string(),
3025 input: json!({"text": "test"}),
3026 is_input_complete: true,
3027 thought_signature: None,
3028 };
3029 fake_model.send_last_completion_stream_text_chunk("Hi!");
3030 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3031 permission_tool_use,
3032 ));
3033 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3034 echo_tool_use.clone(),
3035 ));
3036 fake_model.end_last_completion_stream();
3037 cx.run_until_parked();
3038
3039 // Ensure pending tools are skipped when building a request.
3040 let request = thread
3041 .read_with(cx, |thread, cx| {
3042 thread.build_completion_request(CompletionIntent::EditFile, cx)
3043 })
3044 .unwrap();
3045 assert_eq!(
3046 request.messages[1..],
3047 vec![
3048 LanguageModelRequestMessage {
3049 role: Role::User,
3050 content: vec!["Hey!".into()],
3051 cache: true,
3052 reasoning_details: None,
3053 },
3054 LanguageModelRequestMessage {
3055 role: Role::Assistant,
3056 content: vec![
3057 MessageContent::Text("Hi!".into()),
3058 MessageContent::ToolUse(echo_tool_use.clone())
3059 ],
3060 cache: false,
3061 reasoning_details: None,
3062 },
3063 LanguageModelRequestMessage {
3064 role: Role::User,
3065 content: vec![MessageContent::ToolResult(LanguageModelToolResult {
3066 tool_use_id: echo_tool_use.id.clone(),
3067 tool_name: echo_tool_use.name,
3068 is_error: false,
3069 content: "test".into(),
3070 output: Some("test".into())
3071 })],
3072 cache: false,
3073 reasoning_details: None,
3074 },
3075 ],
3076 );
3077}
3078
3079#[gpui::test]
3080async fn test_agent_connection(cx: &mut TestAppContext) {
3081 cx.update(settings::init);
3082 let templates = Templates::new();
3083
3084 // Initialize language model system with test provider
3085 cx.update(|cx| {
3086 gpui_tokio::init(cx);
3087
3088 let http_client = FakeHttpClient::with_404_response();
3089 let clock = Arc::new(clock::FakeSystemClock::new());
3090 let client = Client::new(clock, http_client, cx);
3091 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3092 language_model::init(client.clone(), cx);
3093 language_models::init(user_store, client.clone(), cx);
3094 LanguageModelRegistry::test(cx);
3095 });
3096 cx.executor().forbid_parking();
3097
3098 // Create a project for new_thread
3099 let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
3100 fake_fs.insert_tree(path!("/test"), json!({})).await;
3101 let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
3102 let cwd = Path::new("/test");
3103 let thread_store = cx.new(|cx| ThreadStore::new(cx));
3104
3105 // Create agent and connection
3106 let agent = NativeAgent::new(
3107 project.clone(),
3108 thread_store,
3109 templates.clone(),
3110 None,
3111 fake_fs.clone(),
3112 &mut cx.to_async(),
3113 )
3114 .await
3115 .unwrap();
3116 let connection = NativeAgentConnection(agent.clone());
3117
3118 // Create a thread using new_thread
3119 let connection_rc = Rc::new(connection.clone());
3120 let acp_thread = cx
3121 .update(|cx| connection_rc.new_session(project, cwd, cx))
3122 .await
3123 .expect("new_thread should succeed");
3124
3125 // Get the session_id from the AcpThread
3126 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
3127
3128 // Test model_selector returns Some
3129 let selector_opt = connection.model_selector(&session_id);
3130 assert!(
3131 selector_opt.is_some(),
3132 "agent should always support ModelSelector"
3133 );
3134 let selector = selector_opt.unwrap();
3135
3136 // Test list_models
3137 let listed_models = cx
3138 .update(|cx| selector.list_models(cx))
3139 .await
3140 .expect("list_models should succeed");
3141 let AgentModelList::Grouped(listed_models) = listed_models else {
3142 panic!("Unexpected model list type");
3143 };
3144 assert!(!listed_models.is_empty(), "should have at least one model");
3145 assert_eq!(
3146 listed_models[&AgentModelGroupName("Fake".into())][0]
3147 .id
3148 .0
3149 .as_ref(),
3150 "fake/fake"
3151 );
3152
3153 // Test selected_model returns the default
3154 let model = cx
3155 .update(|cx| selector.selected_model(cx))
3156 .await
3157 .expect("selected_model should succeed");
3158 let model = cx
3159 .update(|cx| agent.read(cx).models().model_from_id(&model.id))
3160 .unwrap();
3161 let model = model.as_fake();
3162 assert_eq!(model.id().0, "fake", "should return default model");
3163
3164 let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
3165 cx.run_until_parked();
3166 model.send_last_completion_stream_text_chunk("def");
3167 cx.run_until_parked();
3168 acp_thread.read_with(cx, |thread, cx| {
3169 assert_eq!(
3170 thread.to_markdown(cx),
3171 indoc! {"
3172 ## User
3173
3174 abc
3175
3176 ## Assistant
3177
3178 def
3179
3180 "}
3181 )
3182 });
3183
3184 // Test cancel
3185 cx.update(|cx| connection.cancel(&session_id, cx));
3186 request.await.expect("prompt should fail gracefully");
3187
3188 // Explicitly close the session and drop the ACP thread.
3189 cx.update(|cx| Rc::new(connection.clone()).close_session(&session_id, cx))
3190 .await
3191 .unwrap();
3192 drop(acp_thread);
3193 let result = cx
3194 .update(|cx| {
3195 connection.prompt(
3196 Some(acp_thread::UserMessageId::new()),
3197 acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
3198 cx,
3199 )
3200 })
3201 .await;
3202 assert_eq!(
3203 result.as_ref().unwrap_err().to_string(),
3204 "Session not found",
3205 "unexpected result: {:?}",
3206 result
3207 );
3208}
3209
3210#[gpui::test]
3211async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
3212 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3213 thread.update(cx, |thread, _cx| thread.add_tool(EchoTool));
3214 let fake_model = model.as_fake();
3215
3216 let mut events = thread
3217 .update(cx, |thread, cx| {
3218 thread.send(UserMessageId::new(), ["Echo something"], cx)
3219 })
3220 .unwrap();
3221 cx.run_until_parked();
3222
3223 // Simulate streaming partial input.
3224 let input = json!({});
3225 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3226 LanguageModelToolUse {
3227 id: "1".into(),
3228 name: EchoTool::NAME.into(),
3229 raw_input: input.to_string(),
3230 input,
3231 is_input_complete: false,
3232 thought_signature: None,
3233 },
3234 ));
3235
3236 // Input streaming completed
3237 let input = json!({ "text": "Hello!" });
3238 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3239 LanguageModelToolUse {
3240 id: "1".into(),
3241 name: "echo".into(),
3242 raw_input: input.to_string(),
3243 input,
3244 is_input_complete: true,
3245 thought_signature: None,
3246 },
3247 ));
3248 fake_model.end_last_completion_stream();
3249 cx.run_until_parked();
3250
3251 let tool_call = expect_tool_call(&mut events).await;
3252 assert_eq!(
3253 tool_call,
3254 acp::ToolCall::new("1", "Echo")
3255 .raw_input(json!({}))
3256 .meta(acp::Meta::from_iter([("tool_name".into(), "echo".into())]))
3257 );
3258 let update = expect_tool_call_update_fields(&mut events).await;
3259 assert_eq!(
3260 update,
3261 acp::ToolCallUpdate::new(
3262 "1",
3263 acp::ToolCallUpdateFields::new()
3264 .title("Echo")
3265 .kind(acp::ToolKind::Other)
3266 .raw_input(json!({ "text": "Hello!"}))
3267 )
3268 );
3269 let update = expect_tool_call_update_fields(&mut events).await;
3270 assert_eq!(
3271 update,
3272 acp::ToolCallUpdate::new(
3273 "1",
3274 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3275 )
3276 );
3277 let update = expect_tool_call_update_fields(&mut events).await;
3278 assert_eq!(
3279 update,
3280 acp::ToolCallUpdate::new(
3281 "1",
3282 acp::ToolCallUpdateFields::new()
3283 .status(acp::ToolCallStatus::Completed)
3284 .raw_output("Hello!")
3285 )
3286 );
3287}
3288
3289#[gpui::test]
3290async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3291 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3292 let fake_model = model.as_fake();
3293
3294 let mut events = thread
3295 .update(cx, |thread, cx| {
3296 thread.send(UserMessageId::new(), ["Hello!"], cx)
3297 })
3298 .unwrap();
3299 cx.run_until_parked();
3300
3301 fake_model.send_last_completion_stream_text_chunk("Hey!");
3302 fake_model.end_last_completion_stream();
3303
3304 let mut retry_events = Vec::new();
3305 while let Some(Ok(event)) = events.next().await {
3306 match event {
3307 ThreadEvent::Retry(retry_status) => {
3308 retry_events.push(retry_status);
3309 }
3310 ThreadEvent::Stop(..) => break,
3311 _ => {}
3312 }
3313 }
3314
3315 assert_eq!(retry_events.len(), 0);
3316 thread.read_with(cx, |thread, _cx| {
3317 assert_eq!(
3318 thread.to_markdown(),
3319 indoc! {"
3320 ## User
3321
3322 Hello!
3323
3324 ## Assistant
3325
3326 Hey!
3327 "}
3328 )
3329 });
3330}
3331
3332#[gpui::test]
3333async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3334 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3335 let fake_model = model.as_fake();
3336
3337 let mut events = thread
3338 .update(cx, |thread, cx| {
3339 thread.send(UserMessageId::new(), ["Hello!"], cx)
3340 })
3341 .unwrap();
3342 cx.run_until_parked();
3343
3344 fake_model.send_last_completion_stream_text_chunk("Hey,");
3345 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3346 provider: LanguageModelProviderName::new("Anthropic"),
3347 retry_after: Some(Duration::from_secs(3)),
3348 });
3349 fake_model.end_last_completion_stream();
3350
3351 cx.executor().advance_clock(Duration::from_secs(3));
3352 cx.run_until_parked();
3353
3354 fake_model.send_last_completion_stream_text_chunk("there!");
3355 fake_model.end_last_completion_stream();
3356 cx.run_until_parked();
3357
3358 let mut retry_events = Vec::new();
3359 while let Some(Ok(event)) = events.next().await {
3360 match event {
3361 ThreadEvent::Retry(retry_status) => {
3362 retry_events.push(retry_status);
3363 }
3364 ThreadEvent::Stop(..) => break,
3365 _ => {}
3366 }
3367 }
3368
3369 assert_eq!(retry_events.len(), 1);
3370 assert!(matches!(
3371 retry_events[0],
3372 acp_thread::RetryStatus { attempt: 1, .. }
3373 ));
3374 thread.read_with(cx, |thread, _cx| {
3375 assert_eq!(
3376 thread.to_markdown(),
3377 indoc! {"
3378 ## User
3379
3380 Hello!
3381
3382 ## Assistant
3383
3384 Hey,
3385
3386 [resume]
3387
3388 ## Assistant
3389
3390 there!
3391 "}
3392 )
3393 });
3394}
3395
3396#[gpui::test]
3397async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3398 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3399 let fake_model = model.as_fake();
3400
3401 let events = thread
3402 .update(cx, |thread, cx| {
3403 thread.add_tool(EchoTool);
3404 thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3405 })
3406 .unwrap();
3407 cx.run_until_parked();
3408
3409 let tool_use_1 = LanguageModelToolUse {
3410 id: "tool_1".into(),
3411 name: EchoTool::NAME.into(),
3412 raw_input: json!({"text": "test"}).to_string(),
3413 input: json!({"text": "test"}),
3414 is_input_complete: true,
3415 thought_signature: None,
3416 };
3417 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3418 tool_use_1.clone(),
3419 ));
3420 fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3421 provider: LanguageModelProviderName::new("Anthropic"),
3422 retry_after: Some(Duration::from_secs(3)),
3423 });
3424 fake_model.end_last_completion_stream();
3425
3426 cx.executor().advance_clock(Duration::from_secs(3));
3427 let completion = fake_model.pending_completions().pop().unwrap();
3428 assert_eq!(
3429 completion.messages[1..],
3430 vec![
3431 LanguageModelRequestMessage {
3432 role: Role::User,
3433 content: vec!["Call the echo tool!".into()],
3434 cache: false,
3435 reasoning_details: None,
3436 },
3437 LanguageModelRequestMessage {
3438 role: Role::Assistant,
3439 content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3440 cache: false,
3441 reasoning_details: None,
3442 },
3443 LanguageModelRequestMessage {
3444 role: Role::User,
3445 content: vec![language_model::MessageContent::ToolResult(
3446 LanguageModelToolResult {
3447 tool_use_id: tool_use_1.id.clone(),
3448 tool_name: tool_use_1.name.clone(),
3449 is_error: false,
3450 content: "test".into(),
3451 output: Some("test".into())
3452 }
3453 )],
3454 cache: true,
3455 reasoning_details: None,
3456 },
3457 ]
3458 );
3459
3460 fake_model.send_last_completion_stream_text_chunk("Done");
3461 fake_model.end_last_completion_stream();
3462 cx.run_until_parked();
3463 events.collect::<Vec<_>>().await;
3464 thread.read_with(cx, |thread, _cx| {
3465 assert_eq!(
3466 thread.last_received_or_pending_message(),
3467 Some(Message::Agent(AgentMessage {
3468 content: vec![AgentMessageContent::Text("Done".into())],
3469 tool_results: IndexMap::default(),
3470 reasoning_details: None,
3471 }))
3472 );
3473 })
3474}
3475
3476#[gpui::test]
3477async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3478 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3479 let fake_model = model.as_fake();
3480
3481 let mut events = thread
3482 .update(cx, |thread, cx| {
3483 thread.send(UserMessageId::new(), ["Hello!"], cx)
3484 })
3485 .unwrap();
3486 cx.run_until_parked();
3487
3488 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3489 fake_model.send_last_completion_stream_error(
3490 LanguageModelCompletionError::ServerOverloaded {
3491 provider: LanguageModelProviderName::new("Anthropic"),
3492 retry_after: Some(Duration::from_secs(3)),
3493 },
3494 );
3495 fake_model.end_last_completion_stream();
3496 cx.executor().advance_clock(Duration::from_secs(3));
3497 cx.run_until_parked();
3498 }
3499
3500 let mut errors = Vec::new();
3501 let mut retry_events = Vec::new();
3502 while let Some(event) = events.next().await {
3503 match event {
3504 Ok(ThreadEvent::Retry(retry_status)) => {
3505 retry_events.push(retry_status);
3506 }
3507 Ok(ThreadEvent::Stop(..)) => break,
3508 Err(error) => errors.push(error),
3509 _ => {}
3510 }
3511 }
3512
3513 assert_eq!(
3514 retry_events.len(),
3515 crate::thread::MAX_RETRY_ATTEMPTS as usize
3516 );
3517 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3518 assert_eq!(retry_events[i].attempt, i + 1);
3519 }
3520 assert_eq!(errors.len(), 1);
3521 let error = errors[0]
3522 .downcast_ref::<LanguageModelCompletionError>()
3523 .unwrap();
3524 assert!(matches!(
3525 error,
3526 LanguageModelCompletionError::ServerOverloaded { .. }
3527 ));
3528}
3529
3530/// Filters out the stop events for asserting against in tests
3531fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3532 result_events
3533 .into_iter()
3534 .filter_map(|event| match event.unwrap() {
3535 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3536 _ => None,
3537 })
3538 .collect()
3539}
3540
/// Fixture returned by `setup`: the thread under test together with the
/// collaborators it was built from.
struct ThreadTest {
    /// Language model handed to the thread when it was created.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// Context server store obtained from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
3548
/// Selects which language model `setup` wires into the thread under test.
enum TestModel {
    /// A model looked up in the global `LanguageModelRegistry` by id and
    /// authenticated through its provider.
    Sonnet4,
    /// A default-constructed `FakeLanguageModel` driven directly by the test.
    Fake,
}
3553
3554impl TestModel {
3555 fn id(&self) -> LanguageModelId {
3556 match self {
3557 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3558 TestModel::Fake => unreachable!(),
3559 }
3560 }
3561}
3562
3563async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3564 cx.executor().allow_parking();
3565
3566 let fs = FakeFs::new(cx.background_executor.clone());
3567 fs.create_dir(paths::settings_file().parent().unwrap())
3568 .await
3569 .unwrap();
3570 fs.insert_file(
3571 paths::settings_file(),
3572 json!({
3573 "agent": {
3574 "default_profile": "test-profile",
3575 "profiles": {
3576 "test-profile": {
3577 "name": "Test Profile",
3578 "tools": {
3579 EchoTool::NAME: true,
3580 DelayTool::NAME: true,
3581 WordListTool::NAME: true,
3582 ToolRequiringPermission::NAME: true,
3583 InfiniteTool::NAME: true,
3584 CancellationAwareTool::NAME: true,
3585 (TerminalTool::NAME): true,
3586 }
3587 }
3588 }
3589 }
3590 })
3591 .to_string()
3592 .into_bytes(),
3593 )
3594 .await;
3595
3596 cx.update(|cx| {
3597 settings::init(cx);
3598
3599 match model {
3600 TestModel::Fake => {}
3601 TestModel::Sonnet4 => {
3602 gpui_tokio::init(cx);
3603 let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3604 cx.set_http_client(Arc::new(http_client));
3605 let client = Client::production(cx);
3606 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3607 language_model::init(client.clone(), cx);
3608 language_models::init(user_store, client.clone(), cx);
3609 }
3610 };
3611
3612 watch_settings(fs.clone(), cx);
3613 });
3614
3615 let templates = Templates::new();
3616
3617 fs.insert_tree(path!("/test"), json!({})).await;
3618 let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3619
3620 let model = cx
3621 .update(|cx| {
3622 if let TestModel::Fake = model {
3623 Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3624 } else {
3625 let model_id = model.id();
3626 let models = LanguageModelRegistry::read_global(cx);
3627 let model = models
3628 .available_models(cx)
3629 .find(|model| model.id() == model_id)
3630 .unwrap();
3631
3632 let provider = models.provider(&model.provider_id()).unwrap();
3633 let authenticated = provider.authenticate(cx);
3634
3635 cx.spawn(async move |_cx| {
3636 authenticated.await.unwrap();
3637 model
3638 })
3639 }
3640 })
3641 .await;
3642
3643 let project_context = cx.new(|_cx| ProjectContext::default());
3644 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3645 let context_server_registry =
3646 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3647 let thread = cx.new(|cx| {
3648 Thread::new(
3649 project,
3650 project_context.clone(),
3651 context_server_registry,
3652 templates,
3653 Some(model.clone()),
3654 cx,
3655 )
3656 });
3657 ThreadTest {
3658 model,
3659 thread,
3660 project_context,
3661 context_server_store,
3662 fs,
3663 }
3664}
3665
/// Process-start hook that installs `env_logger`, but only when `RUST_LOG` is set
/// to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3673
3674fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3675 let fs = fs.clone();
3676 cx.spawn({
3677 async move |cx| {
3678 let (mut new_settings_content_rx, watcher_task) = settings::watch_config_file(
3679 cx.background_executor(),
3680 fs,
3681 paths::settings_file().clone(),
3682 );
3683 let _watcher_task = watcher_task;
3684
3685 while let Some(new_settings_content) = new_settings_content_rx.next().await {
3686 cx.update(|cx| {
3687 SettingsStore::update_global(cx, |settings, cx| {
3688 settings.set_user_settings(&new_settings_content, cx)
3689 })
3690 })
3691 .ok();
3692 }
3693 }
3694 })
3695 .detach();
3696}
3697
3698fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3699 completion
3700 .tools
3701 .iter()
3702 .map(|tool| tool.name.clone())
3703 .collect()
3704}
3705
/// Registers and starts a fake context server called `name` that advertises
/// `tools`.
///
/// The server is added to the global `ProjectSettings` as an enabled stdio server,
/// then started in `context_server_store` on top of a fake transport that answers
/// `Initialize`, `ListTools`, and `CallTool` requests.
///
/// Returns a receiver yielding, for every `CallTool` request, the request
/// parameters together with a oneshot sender through which the test supplies the
/// response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Add a settings entry for the server to the global project settings.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                remote: false,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Channel over which tool calls are handed to the test.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        // Handshake: report the latest protocol version and tool support.
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        // Tool listing: always answer with the caller-provided tools.
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        // Tool call: forward the parameters to the test and wait for its response.
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Run pending work before handing the receiver back.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
3785
3786#[gpui::test]
3787async fn test_tokens_before_message(cx: &mut TestAppContext) {
3788 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3789 let fake_model = model.as_fake();
3790
3791 // First message
3792 let message_1_id = UserMessageId::new();
3793 thread
3794 .update(cx, |thread, cx| {
3795 thread.send(message_1_id.clone(), ["First message"], cx)
3796 })
3797 .unwrap();
3798 cx.run_until_parked();
3799
3800 // Before any response, tokens_before_message should return None for first message
3801 thread.read_with(cx, |thread, _| {
3802 assert_eq!(
3803 thread.tokens_before_message(&message_1_id),
3804 None,
3805 "First message should have no tokens before it"
3806 );
3807 });
3808
3809 // Complete first message with usage
3810 fake_model.send_last_completion_stream_text_chunk("Response 1");
3811 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3812 language_model::TokenUsage {
3813 input_tokens: 100,
3814 output_tokens: 50,
3815 cache_creation_input_tokens: 0,
3816 cache_read_input_tokens: 0,
3817 },
3818 ));
3819 fake_model.end_last_completion_stream();
3820 cx.run_until_parked();
3821
3822 // First message still has no tokens before it
3823 thread.read_with(cx, |thread, _| {
3824 assert_eq!(
3825 thread.tokens_before_message(&message_1_id),
3826 None,
3827 "First message should still have no tokens before it after response"
3828 );
3829 });
3830
3831 // Second message
3832 let message_2_id = UserMessageId::new();
3833 thread
3834 .update(cx, |thread, cx| {
3835 thread.send(message_2_id.clone(), ["Second message"], cx)
3836 })
3837 .unwrap();
3838 cx.run_until_parked();
3839
3840 // Second message should have first message's input tokens before it
3841 thread.read_with(cx, |thread, _| {
3842 assert_eq!(
3843 thread.tokens_before_message(&message_2_id),
3844 Some(100),
3845 "Second message should have 100 tokens before it (from first request)"
3846 );
3847 });
3848
3849 // Complete second message
3850 fake_model.send_last_completion_stream_text_chunk("Response 2");
3851 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3852 language_model::TokenUsage {
3853 input_tokens: 250, // Total for this request (includes previous context)
3854 output_tokens: 75,
3855 cache_creation_input_tokens: 0,
3856 cache_read_input_tokens: 0,
3857 },
3858 ));
3859 fake_model.end_last_completion_stream();
3860 cx.run_until_parked();
3861
3862 // Third message
3863 let message_3_id = UserMessageId::new();
3864 thread
3865 .update(cx, |thread, cx| {
3866 thread.send(message_3_id.clone(), ["Third message"], cx)
3867 })
3868 .unwrap();
3869 cx.run_until_parked();
3870
3871 // Third message should have second message's input tokens (250) before it
3872 thread.read_with(cx, |thread, _| {
3873 assert_eq!(
3874 thread.tokens_before_message(&message_3_id),
3875 Some(250),
3876 "Third message should have 250 tokens before it (from second request)"
3877 );
3878 // Second message should still have 100
3879 assert_eq!(
3880 thread.tokens_before_message(&message_2_id),
3881 Some(100),
3882 "Second message should still have 100 tokens before it"
3883 );
3884 // First message still has none
3885 assert_eq!(
3886 thread.tokens_before_message(&message_1_id),
3887 None,
3888 "First message should still have no tokens before it"
3889 );
3890 });
3891}
3892
3893#[gpui::test]
3894async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3895 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3896 let fake_model = model.as_fake();
3897
3898 // Set up three messages with responses
3899 let message_1_id = UserMessageId::new();
3900 thread
3901 .update(cx, |thread, cx| {
3902 thread.send(message_1_id.clone(), ["Message 1"], cx)
3903 })
3904 .unwrap();
3905 cx.run_until_parked();
3906 fake_model.send_last_completion_stream_text_chunk("Response 1");
3907 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3908 language_model::TokenUsage {
3909 input_tokens: 100,
3910 output_tokens: 50,
3911 cache_creation_input_tokens: 0,
3912 cache_read_input_tokens: 0,
3913 },
3914 ));
3915 fake_model.end_last_completion_stream();
3916 cx.run_until_parked();
3917
3918 let message_2_id = UserMessageId::new();
3919 thread
3920 .update(cx, |thread, cx| {
3921 thread.send(message_2_id.clone(), ["Message 2"], cx)
3922 })
3923 .unwrap();
3924 cx.run_until_parked();
3925 fake_model.send_last_completion_stream_text_chunk("Response 2");
3926 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3927 language_model::TokenUsage {
3928 input_tokens: 250,
3929 output_tokens: 75,
3930 cache_creation_input_tokens: 0,
3931 cache_read_input_tokens: 0,
3932 },
3933 ));
3934 fake_model.end_last_completion_stream();
3935 cx.run_until_parked();
3936
3937 // Verify initial state
3938 thread.read_with(cx, |thread, _| {
3939 assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3940 });
3941
3942 // Truncate at message 2 (removes message 2 and everything after)
3943 thread
3944 .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3945 .unwrap();
3946 cx.run_until_parked();
3947
3948 // After truncation, message_2_id no longer exists, so lookup should return None
3949 thread.read_with(cx, |thread, _| {
3950 assert_eq!(
3951 thread.tokens_before_message(&message_2_id),
3952 None,
3953 "After truncation, message 2 no longer exists"
3954 );
3955 // Message 1 still exists but has no tokens before it
3956 assert_eq!(
3957 thread.tokens_before_message(&message_1_id),
3958 None,
3959 "First message still has no tokens before it"
3960 );
3961 });
3962}
3963
/// Exercises the terminal tool's permission rules in four scenarios, each of which
/// overrides the global `AgentSettings` before running the tool:
///
/// 1. an `always_deny` pattern blocks a matching command;
/// 2. an `always_allow` pattern runs a matching command without confirmation,
///    even when the tool default is `Deny`;
/// 3. an `always_confirm` pattern still requests authorization when both the
///    global and the tool default are `Allow`;
/// 4. a tool-specific `Deny` default blocks a command even when the global
///    default is `Allow`.
#[gpui::test]
async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({})).await;
    let project = Project::test(fs, ["/root".as_ref()], cx).await;

    // Test 1: Deny rule blocks command
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Confirm),
                    always_allow: vec![],
                    always_deny: vec![
                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
                    ],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "rm -rf /".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // The command matches the deny pattern, so the run must fail with a
        // message saying it was blocked.
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by deny rule"
        );
        let err_msg = result.unwrap_err().to_lowercase();
        assert!(
            err_msg.contains("blocked"),
            "error should mention the command was blocked"
        );
    }

    // Test 2: Allow rule skips confirmation (and overrides default: Deny)
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![
                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
                    ],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // The first update must already carry terminal content, i.e. the command
        // ran without a preceding authorization request.
        let update = rx.expect_update_fields().await;
        assert!(
            update.content.iter().any(|blocks| {
                blocks
                    .iter()
                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
            }),
            "expected terminal content (allow rule should skip confirmation and override default deny)"
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "expected command to succeed without confirmation"
        );
    }

    // Test 3: global default: allow does NOT override always_confirm patterns
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Allow),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![
                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
                    ],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, mut rx) = crate::ToolCallEventStream::test();

        let _task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "sudo rm file".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // With global default: allow, confirm patterns are still respected
        // The expect_authorization() call will panic if no authorization is requested,
        // which validates that the confirm pattern still triggers confirmation
        let _auth = rx.expect_authorization().await;

        // The authorization is never answered; the still-pending run is dropped.
        drop(_task);
    }

    // Test 4: tool-specific default: deny is respected even with global default: allow
    {
        let environment = Rc::new(cx.update(|cx| {
            FakeThreadEnvironment::default()
                .with_terminal(FakeTerminalHandle::new_with_immediate_exit(cx, 0))
        }));

        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            settings.tool_permissions.tools.insert(
                TerminalTool::NAME.into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    always_allow: vec![],
                    always_deny: vec![],
                    always_confirm: vec![],
                    invalid_patterns: vec![],
                },
            );
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        #[allow(clippy::arc_with_non_send_sync)]
        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
        let (event_stream, _rx) = crate::ToolCallEventStream::test();

        let task = cx.update(|cx| {
            tool.run(
                ToolInput::resolved(crate::TerminalToolInput {
                    command: "echo hello".to_string(),
                    cd: ".".to_string(),
                    timeout_ms: None,
                }),
                event_stream,
                cx,
            )
        });

        // tool-specific default: deny is respected even with global default: allow
        let result = task.await;
        assert!(
            result.is_err(),
            "expected command to be blocked by tool-specific deny default"
        );
        let err_msg = result.unwrap_err().to_lowercase();
        assert!(
            err_msg.contains("disabled"),
            "error should mention the tool is disabled, got: {err_msg}"
        );
    }
}
4181
4182#[gpui::test]
4183async fn test_subagent_tool_call_end_to_end(cx: &mut TestAppContext) {
4184 init_test(cx);
4185 cx.update(|cx| {
4186 LanguageModelRegistry::test(cx);
4187 });
4188 cx.update(|cx| {
4189 cx.update_flags(true, vec!["subagents".to_string()]);
4190 });
4191
4192 let fs = FakeFs::new(cx.executor());
4193 fs.insert_tree(
4194 "/",
4195 json!({
4196 "a": {
4197 "b.md": "Lorem"
4198 }
4199 }),
4200 )
4201 .await;
4202 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4203 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4204 let agent = NativeAgent::new(
4205 project.clone(),
4206 thread_store.clone(),
4207 Templates::new(),
4208 None,
4209 fs.clone(),
4210 &mut cx.to_async(),
4211 )
4212 .await
4213 .unwrap();
4214 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4215
4216 let acp_thread = cx
4217 .update(|cx| {
4218 connection
4219 .clone()
4220 .new_session(project.clone(), Path::new(""), cx)
4221 })
4222 .await
4223 .unwrap();
4224 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4225 let thread = agent.read_with(cx, |agent, _| {
4226 agent.sessions.get(&session_id).unwrap().thread.clone()
4227 });
4228 let model = Arc::new(FakeLanguageModel::default());
4229
4230 // Ensure empty threads are not saved, even if they get mutated.
4231 thread.update(cx, |thread, cx| {
4232 thread.set_model(model.clone(), cx);
4233 });
4234 cx.run_until_parked();
4235
4236 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4237 cx.run_until_parked();
4238 model.send_last_completion_stream_text_chunk("spawning subagent");
4239 let subagent_tool_input = SpawnAgentToolInput {
4240 label: "label".to_string(),
4241 message: "subagent task prompt".to_string(),
4242 session_id: None,
4243 };
4244 let subagent_tool_use = LanguageModelToolUse {
4245 id: "subagent_1".into(),
4246 name: SpawnAgentTool::NAME.into(),
4247 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4248 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4249 is_input_complete: true,
4250 thought_signature: None,
4251 };
4252 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4253 subagent_tool_use,
4254 ));
4255 model.end_last_completion_stream();
4256
4257 cx.run_until_parked();
4258
4259 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4260 thread
4261 .running_subagent_ids(cx)
4262 .get(0)
4263 .expect("subagent thread should be running")
4264 .clone()
4265 });
4266
4267 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4268 agent
4269 .sessions
4270 .get(&subagent_session_id)
4271 .expect("subagent session should exist")
4272 .acp_thread
4273 .clone()
4274 });
4275
4276 model.send_last_completion_stream_text_chunk("subagent task response");
4277 model.end_last_completion_stream();
4278
4279 cx.run_until_parked();
4280
4281 assert_eq!(
4282 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4283 indoc! {"
4284 ## User
4285
4286 subagent task prompt
4287
4288 ## Assistant
4289
4290 subagent task response
4291
4292 "}
4293 );
4294
4295 model.send_last_completion_stream_text_chunk("Response");
4296 model.end_last_completion_stream();
4297
4298 send.await.unwrap();
4299
4300 assert_eq!(
4301 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4302 indoc! {r#"
4303 ## User
4304
4305 Prompt
4306
4307 ## Assistant
4308
4309 spawning subagent
4310
4311 **Tool Call: label**
4312 Status: Completed
4313
4314 subagent task response
4315
4316 ## Assistant
4317
4318 Response
4319
4320 "#},
4321 );
4322}
4323
4324#[gpui::test]
4325async fn test_subagent_tool_output_does_not_include_thinking(cx: &mut TestAppContext) {
4326 init_test(cx);
4327 cx.update(|cx| {
4328 LanguageModelRegistry::test(cx);
4329 });
4330 cx.update(|cx| {
4331 cx.update_flags(true, vec!["subagents".to_string()]);
4332 });
4333
4334 let fs = FakeFs::new(cx.executor());
4335 fs.insert_tree(
4336 "/",
4337 json!({
4338 "a": {
4339 "b.md": "Lorem"
4340 }
4341 }),
4342 )
4343 .await;
4344 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4345 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4346 let agent = NativeAgent::new(
4347 project.clone(),
4348 thread_store.clone(),
4349 Templates::new(),
4350 None,
4351 fs.clone(),
4352 &mut cx.to_async(),
4353 )
4354 .await
4355 .unwrap();
4356 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4357
4358 let acp_thread = cx
4359 .update(|cx| {
4360 connection
4361 .clone()
4362 .new_session(project.clone(), Path::new(""), cx)
4363 })
4364 .await
4365 .unwrap();
4366 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4367 let thread = agent.read_with(cx, |agent, _| {
4368 agent.sessions.get(&session_id).unwrap().thread.clone()
4369 });
4370 let model = Arc::new(FakeLanguageModel::default());
4371
4372 // Ensure empty threads are not saved, even if they get mutated.
4373 thread.update(cx, |thread, cx| {
4374 thread.set_model(model.clone(), cx);
4375 });
4376 cx.run_until_parked();
4377
4378 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4379 cx.run_until_parked();
4380 model.send_last_completion_stream_text_chunk("spawning subagent");
4381 let subagent_tool_input = SpawnAgentToolInput {
4382 label: "label".to_string(),
4383 message: "subagent task prompt".to_string(),
4384 session_id: None,
4385 };
4386 let subagent_tool_use = LanguageModelToolUse {
4387 id: "subagent_1".into(),
4388 name: SpawnAgentTool::NAME.into(),
4389 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4390 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4391 is_input_complete: true,
4392 thought_signature: None,
4393 };
4394 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4395 subagent_tool_use,
4396 ));
4397 model.end_last_completion_stream();
4398
4399 cx.run_until_parked();
4400
4401 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4402 thread
4403 .running_subagent_ids(cx)
4404 .get(0)
4405 .expect("subagent thread should be running")
4406 .clone()
4407 });
4408
4409 let subagent_thread = agent.read_with(cx, |agent, _cx| {
4410 agent
4411 .sessions
4412 .get(&subagent_session_id)
4413 .expect("subagent session should exist")
4414 .acp_thread
4415 .clone()
4416 });
4417
4418 model.send_last_completion_stream_text_chunk("subagent task response 1");
4419 model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
4420 text: "thinking more about the subagent task".into(),
4421 signature: None,
4422 });
4423 model.send_last_completion_stream_text_chunk("subagent task response 2");
4424 model.end_last_completion_stream();
4425
4426 cx.run_until_parked();
4427
4428 assert_eq!(
4429 subagent_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4430 indoc! {"
4431 ## User
4432
4433 subagent task prompt
4434
4435 ## Assistant
4436
4437 subagent task response 1
4438
4439 <thinking>
4440 thinking more about the subagent task
4441 </thinking>
4442
4443 subagent task response 2
4444
4445 "}
4446 );
4447
4448 model.send_last_completion_stream_text_chunk("Response");
4449 model.end_last_completion_stream();
4450
4451 send.await.unwrap();
4452
4453 assert_eq!(
4454 acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4455 indoc! {r#"
4456 ## User
4457
4458 Prompt
4459
4460 ## Assistant
4461
4462 spawning subagent
4463
4464 **Tool Call: label**
4465 Status: Completed
4466
4467 subagent task response 1
4468
4469 subagent task response 2
4470
4471 ## Assistant
4472
4473 Response
4474
4475 "#},
4476 );
4477}
4478
4479#[gpui::test]
4480async fn test_subagent_tool_call_cancellation_during_task_prompt(cx: &mut TestAppContext) {
4481 init_test(cx);
4482 cx.update(|cx| {
4483 LanguageModelRegistry::test(cx);
4484 });
4485 cx.update(|cx| {
4486 cx.update_flags(true, vec!["subagents".to_string()]);
4487 });
4488
4489 let fs = FakeFs::new(cx.executor());
4490 fs.insert_tree(
4491 "/",
4492 json!({
4493 "a": {
4494 "b.md": "Lorem"
4495 }
4496 }),
4497 )
4498 .await;
4499 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4500 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4501 let agent = NativeAgent::new(
4502 project.clone(),
4503 thread_store.clone(),
4504 Templates::new(),
4505 None,
4506 fs.clone(),
4507 &mut cx.to_async(),
4508 )
4509 .await
4510 .unwrap();
4511 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4512
4513 let acp_thread = cx
4514 .update(|cx| {
4515 connection
4516 .clone()
4517 .new_session(project.clone(), Path::new(""), cx)
4518 })
4519 .await
4520 .unwrap();
4521 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4522 let thread = agent.read_with(cx, |agent, _| {
4523 agent.sessions.get(&session_id).unwrap().thread.clone()
4524 });
4525 let model = Arc::new(FakeLanguageModel::default());
4526
4527 // Ensure empty threads are not saved, even if they get mutated.
4528 thread.update(cx, |thread, cx| {
4529 thread.set_model(model.clone(), cx);
4530 });
4531 cx.run_until_parked();
4532
4533 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
4534 cx.run_until_parked();
4535 model.send_last_completion_stream_text_chunk("spawning subagent");
4536 let subagent_tool_input = SpawnAgentToolInput {
4537 label: "label".to_string(),
4538 message: "subagent task prompt".to_string(),
4539 session_id: None,
4540 };
4541 let subagent_tool_use = LanguageModelToolUse {
4542 id: "subagent_1".into(),
4543 name: SpawnAgentTool::NAME.into(),
4544 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4545 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4546 is_input_complete: true,
4547 thought_signature: None,
4548 };
4549 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4550 subagent_tool_use,
4551 ));
4552 model.end_last_completion_stream();
4553
4554 cx.run_until_parked();
4555
4556 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4557 thread
4558 .running_subagent_ids(cx)
4559 .get(0)
4560 .expect("subagent thread should be running")
4561 .clone()
4562 });
4563 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4564 agent
4565 .sessions
4566 .get(&subagent_session_id)
4567 .expect("subagent session should exist")
4568 .acp_thread
4569 .clone()
4570 });
4571
4572 // model.send_last_completion_stream_text_chunk("subagent task response");
4573 // model.end_last_completion_stream();
4574
4575 // cx.run_until_parked();
4576
4577 acp_thread.update(cx, |thread, cx| thread.cancel(cx)).await;
4578
4579 cx.run_until_parked();
4580
4581 send.await.unwrap();
4582
4583 acp_thread.read_with(cx, |thread, cx| {
4584 assert_eq!(thread.status(), ThreadStatus::Idle);
4585 assert_eq!(
4586 thread.to_markdown(cx),
4587 indoc! {"
4588 ## User
4589
4590 Prompt
4591
4592 ## Assistant
4593
4594 spawning subagent
4595
4596 **Tool Call: label**
4597 Status: Canceled
4598
4599 "}
4600 );
4601 });
4602 subagent_acp_thread.read_with(cx, |thread, cx| {
4603 assert_eq!(thread.status(), ThreadStatus::Idle);
4604 assert_eq!(
4605 thread.to_markdown(cx),
4606 indoc! {"
4607 ## User
4608
4609 subagent task prompt
4610
4611 "}
4612 );
4613 });
4614}
4615
4616#[gpui::test]
4617async fn test_subagent_tool_resume_session(cx: &mut TestAppContext) {
4618 init_test(cx);
4619 cx.update(|cx| {
4620 LanguageModelRegistry::test(cx);
4621 });
4622 cx.update(|cx| {
4623 cx.update_flags(true, vec!["subagents".to_string()]);
4624 });
4625
4626 let fs = FakeFs::new(cx.executor());
4627 fs.insert_tree(
4628 "/",
4629 json!({
4630 "a": {
4631 "b.md": "Lorem"
4632 }
4633 }),
4634 )
4635 .await;
4636 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
4637 let thread_store = cx.new(|cx| ThreadStore::new(cx));
4638 let agent = NativeAgent::new(
4639 project.clone(),
4640 thread_store.clone(),
4641 Templates::new(),
4642 None,
4643 fs.clone(),
4644 &mut cx.to_async(),
4645 )
4646 .await
4647 .unwrap();
4648 let connection = Rc::new(NativeAgentConnection(agent.clone()));
4649
4650 let acp_thread = cx
4651 .update(|cx| {
4652 connection
4653 .clone()
4654 .new_session(project.clone(), Path::new(""), cx)
4655 })
4656 .await
4657 .unwrap();
4658 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
4659 let thread = agent.read_with(cx, |agent, _| {
4660 agent.sessions.get(&session_id).unwrap().thread.clone()
4661 });
4662 let model = Arc::new(FakeLanguageModel::default());
4663
4664 thread.update(cx, |thread, cx| {
4665 thread.set_model(model.clone(), cx);
4666 });
4667 cx.run_until_parked();
4668
4669 // === First turn: create subagent ===
4670 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
4671 cx.run_until_parked();
4672 model.send_last_completion_stream_text_chunk("spawning subagent");
4673 let subagent_tool_input = SpawnAgentToolInput {
4674 label: "initial task".to_string(),
4675 message: "do the first task".to_string(),
4676 session_id: None,
4677 };
4678 let subagent_tool_use = LanguageModelToolUse {
4679 id: "subagent_1".into(),
4680 name: SpawnAgentTool::NAME.into(),
4681 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
4682 input: serde_json::to_value(&subagent_tool_input).unwrap(),
4683 is_input_complete: true,
4684 thought_signature: None,
4685 };
4686 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
4687 subagent_tool_use,
4688 ));
4689 model.end_last_completion_stream();
4690
4691 cx.run_until_parked();
4692
4693 let subagent_session_id = thread.read_with(cx, |thread, cx| {
4694 thread
4695 .running_subagent_ids(cx)
4696 .get(0)
4697 .expect("subagent thread should be running")
4698 .clone()
4699 });
4700
4701 let subagent_acp_thread = agent.read_with(cx, |agent, _cx| {
4702 agent
4703 .sessions
4704 .get(&subagent_session_id)
4705 .expect("subagent session should exist")
4706 .acp_thread
4707 .clone()
4708 });
4709
4710 // Subagent responds
4711 model.send_last_completion_stream_text_chunk("first task response");
4712 model.end_last_completion_stream();
4713
4714 cx.run_until_parked();
4715
4716 // Parent model responds to complete first turn
4717 model.send_last_completion_stream_text_chunk("First response");
4718 model.end_last_completion_stream();
4719
4720 send.await.unwrap();
4721
4722 // Verify subagent is no longer running
4723 thread.read_with(cx, |thread, cx| {
4724 assert!(
4725 thread.running_subagent_ids(cx).is_empty(),
4726 "subagent should not be running after completion"
4727 );
4728 });
4729
4730 // === Second turn: resume subagent with session_id ===
4731 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
4732 cx.run_until_parked();
4733 model.send_last_completion_stream_text_chunk("resuming subagent");
4734 let resume_tool_input = SpawnAgentToolInput {
4735 label: "follow-up task".to_string(),
4736 message: "do the follow-up task".to_string(),
4737 session_id: Some(subagent_session_id.clone()),
4738 };
4739 let resume_tool_use = LanguageModelToolUse {
4740 id: "subagent_2".into(),
4741 name: SpawnAgentTool::NAME.into(),
4742 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
4743 input: serde_json::to_value(&resume_tool_input).unwrap(),
4744 is_input_complete: true,
4745 thought_signature: None,
4746 };
4747 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
4748 model.end_last_completion_stream();
4749
4750 cx.run_until_parked();
4751
4752 // Subagent should be running again with the same session
4753 thread.read_with(cx, |thread, cx| {
4754 let running = thread.running_subagent_ids(cx);
4755 assert_eq!(running.len(), 1, "subagent should be running");
4756 assert_eq!(running[0], subagent_session_id, "should be same session");
4757 });
4758
4759 // Subagent responds to follow-up
4760 model.send_last_completion_stream_text_chunk("follow-up task response");
4761 model.end_last_completion_stream();
4762
4763 cx.run_until_parked();
4764
4765 // Parent model responds to complete second turn
4766 model.send_last_completion_stream_text_chunk("Second response");
4767 model.end_last_completion_stream();
4768
4769 send2.await.unwrap();
4770
4771 // Verify subagent is no longer running
4772 thread.read_with(cx, |thread, cx| {
4773 assert!(
4774 thread.running_subagent_ids(cx).is_empty(),
4775 "subagent should not be running after resume completion"
4776 );
4777 });
4778
4779 // Verify the subagent's acp thread has both conversation turns
4780 assert_eq!(
4781 subagent_acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx)),
4782 indoc! {"
4783 ## User
4784
4785 do the first task
4786
4787 ## Assistant
4788
4789 first task response
4790
4791 ## User
4792
4793 do the follow-up task
4794
4795 ## Assistant
4796
4797 follow-up task response
4798
4799 "}
4800 );
4801}
4802
4803#[gpui::test]
4804async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
4805 init_test(cx);
4806
4807 cx.update(|cx| {
4808 cx.update_flags(true, vec!["subagents".to_string()]);
4809 });
4810
4811 let fs = FakeFs::new(cx.executor());
4812 fs.insert_tree(path!("/test"), json!({})).await;
4813 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4814 let project_context = cx.new(|_cx| ProjectContext::default());
4815 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4816 let context_server_registry =
4817 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4818 let model = Arc::new(FakeLanguageModel::default());
4819
4820 let environment = Rc::new(cx.update(|cx| {
4821 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4822 }));
4823
4824 let thread = cx.new(|cx| {
4825 let mut thread = Thread::new(
4826 project.clone(),
4827 project_context,
4828 context_server_registry,
4829 Templates::new(),
4830 Some(model),
4831 cx,
4832 );
4833 thread.add_default_tools(environment, cx);
4834 thread
4835 });
4836
4837 thread.read_with(cx, |thread, _| {
4838 assert!(
4839 thread.has_registered_tool(SpawnAgentTool::NAME),
4840 "subagent tool should be present when feature flag is enabled"
4841 );
4842 });
4843}
4844
4845#[gpui::test]
4846async fn test_subagent_thread_inherits_parent_thread_properties(cx: &mut TestAppContext) {
4847 init_test(cx);
4848
4849 cx.update(|cx| {
4850 cx.update_flags(true, vec!["subagents".to_string()]);
4851 });
4852
4853 let fs = FakeFs::new(cx.executor());
4854 fs.insert_tree(path!("/test"), json!({})).await;
4855 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4856 let project_context = cx.new(|_cx| ProjectContext::default());
4857 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4858 let context_server_registry =
4859 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4860 let model = Arc::new(FakeLanguageModel::default());
4861
4862 let parent_thread = cx.new(|cx| {
4863 Thread::new(
4864 project.clone(),
4865 project_context,
4866 context_server_registry,
4867 Templates::new(),
4868 Some(model.clone()),
4869 cx,
4870 )
4871 });
4872
4873 let subagent_thread = cx.new(|cx| Thread::new_subagent(&parent_thread, cx));
4874 subagent_thread.read_with(cx, |subagent_thread, cx| {
4875 assert!(subagent_thread.is_subagent());
4876 assert_eq!(subagent_thread.depth(), 1);
4877 assert_eq!(
4878 subagent_thread.model().map(|model| model.id()),
4879 Some(model.id())
4880 );
4881 assert_eq!(
4882 subagent_thread.parent_thread_id(),
4883 Some(parent_thread.read(cx).id().clone())
4884 );
4885 });
4886}
4887
4888#[gpui::test]
4889async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4890 init_test(cx);
4891
4892 cx.update(|cx| {
4893 cx.update_flags(true, vec!["subagents".to_string()]);
4894 });
4895
4896 let fs = FakeFs::new(cx.executor());
4897 fs.insert_tree(path!("/test"), json!({})).await;
4898 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4899 let project_context = cx.new(|_cx| ProjectContext::default());
4900 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4901 let context_server_registry =
4902 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4903 let model = Arc::new(FakeLanguageModel::default());
4904 let environment = Rc::new(cx.update(|cx| {
4905 FakeThreadEnvironment::default().with_terminal(FakeTerminalHandle::new_never_exits(cx))
4906 }));
4907
4908 let deep_parent_thread = cx.new(|cx| {
4909 let mut thread = Thread::new(
4910 project.clone(),
4911 project_context,
4912 context_server_registry,
4913 Templates::new(),
4914 Some(model.clone()),
4915 cx,
4916 );
4917 thread.set_subagent_context(SubagentContext {
4918 parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4919 depth: MAX_SUBAGENT_DEPTH - 1,
4920 });
4921 thread
4922 });
4923 let deep_subagent_thread = cx.new(|cx| {
4924 let mut thread = Thread::new_subagent(&deep_parent_thread, cx);
4925 thread.add_default_tools(environment, cx);
4926 thread
4927 });
4928
4929 deep_subagent_thread.read_with(cx, |thread, _| {
4930 assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4931 assert!(
4932 !thread.has_registered_tool(SpawnAgentTool::NAME),
4933 "subagent tool should not be present at max depth"
4934 );
4935 });
4936}
4937
4938#[gpui::test]
4939async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4940 init_test(cx);
4941
4942 cx.update(|cx| {
4943 cx.update_flags(true, vec!["subagents".to_string()]);
4944 });
4945
4946 let fs = FakeFs::new(cx.executor());
4947 fs.insert_tree(path!("/test"), json!({})).await;
4948 let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4949 let project_context = cx.new(|_cx| ProjectContext::default());
4950 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4951 let context_server_registry =
4952 cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4953 let model = Arc::new(FakeLanguageModel::default());
4954
4955 let parent = cx.new(|cx| {
4956 Thread::new(
4957 project.clone(),
4958 project_context.clone(),
4959 context_server_registry.clone(),
4960 Templates::new(),
4961 Some(model.clone()),
4962 cx,
4963 )
4964 });
4965
4966 let subagent = cx.new(|cx| Thread::new_subagent(&parent, cx));
4967
4968 parent.update(cx, |thread, _cx| {
4969 thread.register_running_subagent(subagent.downgrade());
4970 });
4971
4972 subagent
4973 .update(cx, |thread, cx| {
4974 thread.send(UserMessageId::new(), ["Do work".to_string()], cx)
4975 })
4976 .unwrap();
4977 cx.run_until_parked();
4978
4979 subagent.read_with(cx, |thread, _| {
4980 assert!(!thread.is_turn_complete(), "subagent should be running");
4981 });
4982
4983 parent.update(cx, |thread, cx| {
4984 thread.cancel(cx).detach();
4985 });
4986
4987 subagent.read_with(cx, |thread, _| {
4988 assert!(
4989 thread.is_turn_complete(),
4990 "subagent should be cancelled when parent cancels"
4991 );
4992 });
4993}
4994
4995#[gpui::test]
4996async fn test_subagent_context_window_warning(cx: &mut TestAppContext) {
4997 init_test(cx);
4998 cx.update(|cx| {
4999 LanguageModelRegistry::test(cx);
5000 });
5001 cx.update(|cx| {
5002 cx.update_flags(true, vec!["subagents".to_string()]);
5003 });
5004
5005 let fs = FakeFs::new(cx.executor());
5006 fs.insert_tree(
5007 "/",
5008 json!({
5009 "a": {
5010 "b.md": "Lorem"
5011 }
5012 }),
5013 )
5014 .await;
5015 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5016 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5017 let agent = NativeAgent::new(
5018 project.clone(),
5019 thread_store.clone(),
5020 Templates::new(),
5021 None,
5022 fs.clone(),
5023 &mut cx.to_async(),
5024 )
5025 .await
5026 .unwrap();
5027 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5028
5029 let acp_thread = cx
5030 .update(|cx| {
5031 connection
5032 .clone()
5033 .new_session(project.clone(), Path::new(""), cx)
5034 })
5035 .await
5036 .unwrap();
5037 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5038 let thread = agent.read_with(cx, |agent, _| {
5039 agent.sessions.get(&session_id).unwrap().thread.clone()
5040 });
5041 let model = Arc::new(FakeLanguageModel::default());
5042
5043 thread.update(cx, |thread, cx| {
5044 thread.set_model(model.clone(), cx);
5045 });
5046 cx.run_until_parked();
5047
5048 // Start the parent turn
5049 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5050 cx.run_until_parked();
5051 model.send_last_completion_stream_text_chunk("spawning subagent");
5052 let subagent_tool_input = SpawnAgentToolInput {
5053 label: "label".to_string(),
5054 message: "subagent task prompt".to_string(),
5055 session_id: None,
5056 };
5057 let subagent_tool_use = LanguageModelToolUse {
5058 id: "subagent_1".into(),
5059 name: SpawnAgentTool::NAME.into(),
5060 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5061 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5062 is_input_complete: true,
5063 thought_signature: None,
5064 };
5065 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5066 subagent_tool_use,
5067 ));
5068 model.end_last_completion_stream();
5069
5070 cx.run_until_parked();
5071
5072 // Verify subagent is running
5073 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5074 thread
5075 .running_subagent_ids(cx)
5076 .get(0)
5077 .expect("subagent thread should be running")
5078 .clone()
5079 });
5080
5081 // Send a usage update that crosses the warning threshold (80% of 1,000,000)
5082 model.send_last_completion_stream_text_chunk("partial work");
5083 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5084 TokenUsage {
5085 input_tokens: 850_000,
5086 output_tokens: 0,
5087 cache_creation_input_tokens: 0,
5088 cache_read_input_tokens: 0,
5089 },
5090 ));
5091
5092 cx.run_until_parked();
5093
5094 // The subagent should no longer be running
5095 thread.read_with(cx, |thread, cx| {
5096 assert!(
5097 thread.running_subagent_ids(cx).is_empty(),
5098 "subagent should be stopped after context window warning"
5099 );
5100 });
5101
5102 // The parent model should get a new completion request to respond to the tool error
5103 model.send_last_completion_stream_text_chunk("Response after warning");
5104 model.end_last_completion_stream();
5105
5106 send.await.unwrap();
5107
5108 // Verify the parent thread shows the warning error in the tool call
5109 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5110 assert!(
5111 markdown.contains("nearing the end of its context window"),
5112 "tool output should contain context window warning message, got:\n{markdown}"
5113 );
5114 assert!(
5115 markdown.contains("Status: Failed"),
5116 "tool call should have Failed status, got:\n{markdown}"
5117 );
5118
5119 // Verify the subagent session still exists (can be resumed)
5120 agent.read_with(cx, |agent, _cx| {
5121 assert!(
5122 agent.sessions.contains_key(&subagent_session_id),
5123 "subagent session should still exist for potential resume"
5124 );
5125 });
5126}
5127
5128#[gpui::test]
5129async fn test_subagent_no_context_window_warning_when_already_at_warning(cx: &mut TestAppContext) {
5130 init_test(cx);
5131 cx.update(|cx| {
5132 LanguageModelRegistry::test(cx);
5133 });
5134 cx.update(|cx| {
5135 cx.update_flags(true, vec!["subagents".to_string()]);
5136 });
5137
5138 let fs = FakeFs::new(cx.executor());
5139 fs.insert_tree(
5140 "/",
5141 json!({
5142 "a": {
5143 "b.md": "Lorem"
5144 }
5145 }),
5146 )
5147 .await;
5148 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5149 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5150 let agent = NativeAgent::new(
5151 project.clone(),
5152 thread_store.clone(),
5153 Templates::new(),
5154 None,
5155 fs.clone(),
5156 &mut cx.to_async(),
5157 )
5158 .await
5159 .unwrap();
5160 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5161
5162 let acp_thread = cx
5163 .update(|cx| {
5164 connection
5165 .clone()
5166 .new_session(project.clone(), Path::new(""), cx)
5167 })
5168 .await
5169 .unwrap();
5170 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5171 let thread = agent.read_with(cx, |agent, _| {
5172 agent.sessions.get(&session_id).unwrap().thread.clone()
5173 });
5174 let model = Arc::new(FakeLanguageModel::default());
5175
5176 thread.update(cx, |thread, cx| {
5177 thread.set_model(model.clone(), cx);
5178 });
5179 cx.run_until_parked();
5180
5181 // === First turn: create subagent, trigger context window warning ===
5182 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("First prompt", cx));
5183 cx.run_until_parked();
5184 model.send_last_completion_stream_text_chunk("spawning subagent");
5185 let subagent_tool_input = SpawnAgentToolInput {
5186 label: "initial task".to_string(),
5187 message: "do the first task".to_string(),
5188 session_id: None,
5189 };
5190 let subagent_tool_use = LanguageModelToolUse {
5191 id: "subagent_1".into(),
5192 name: SpawnAgentTool::NAME.into(),
5193 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5194 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5195 is_input_complete: true,
5196 thought_signature: None,
5197 };
5198 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5199 subagent_tool_use,
5200 ));
5201 model.end_last_completion_stream();
5202
5203 cx.run_until_parked();
5204
5205 let subagent_session_id = thread.read_with(cx, |thread, cx| {
5206 thread
5207 .running_subagent_ids(cx)
5208 .get(0)
5209 .expect("subagent thread should be running")
5210 .clone()
5211 });
5212
5213 // Subagent sends a usage update that crosses the warning threshold.
5214 // This triggers Normal→Warning, stopping the subagent.
5215 model.send_last_completion_stream_text_chunk("partial work");
5216 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5217 TokenUsage {
5218 input_tokens: 850_000,
5219 output_tokens: 0,
5220 cache_creation_input_tokens: 0,
5221 cache_read_input_tokens: 0,
5222 },
5223 ));
5224
5225 cx.run_until_parked();
5226
5227 // Verify the first turn was stopped with a context window warning
5228 thread.read_with(cx, |thread, cx| {
5229 assert!(
5230 thread.running_subagent_ids(cx).is_empty(),
5231 "subagent should be stopped after context window warning"
5232 );
5233 });
5234
5235 // Parent model responds to complete first turn
5236 model.send_last_completion_stream_text_chunk("First response");
5237 model.end_last_completion_stream();
5238
5239 send.await.unwrap();
5240
5241 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5242 assert!(
5243 markdown.contains("nearing the end of its context window"),
5244 "first turn should have context window warning, got:\n{markdown}"
5245 );
5246
5247 // === Second turn: resume the same subagent (now at Warning level) ===
5248 let send2 = acp_thread.update(cx, |thread, cx| thread.send_raw("Follow up", cx));
5249 cx.run_until_parked();
5250 model.send_last_completion_stream_text_chunk("resuming subagent");
5251 let resume_tool_input = SpawnAgentToolInput {
5252 label: "follow-up task".to_string(),
5253 message: "do the follow-up task".to_string(),
5254 session_id: Some(subagent_session_id.clone()),
5255 };
5256 let resume_tool_use = LanguageModelToolUse {
5257 id: "subagent_2".into(),
5258 name: SpawnAgentTool::NAME.into(),
5259 raw_input: serde_json::to_string(&resume_tool_input).unwrap(),
5260 input: serde_json::to_value(&resume_tool_input).unwrap(),
5261 is_input_complete: true,
5262 thought_signature: None,
5263 };
5264 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(resume_tool_use));
5265 model.end_last_completion_stream();
5266
5267 cx.run_until_parked();
5268
5269 // Subagent responds with tokens still at warning level (no worse).
5270 // Since ratio_before_prompt was already Warning, this should NOT
5271 // trigger the context window warning again.
5272 model.send_last_completion_stream_text_chunk("follow-up task response");
5273 model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
5274 TokenUsage {
5275 input_tokens: 870_000,
5276 output_tokens: 0,
5277 cache_creation_input_tokens: 0,
5278 cache_read_input_tokens: 0,
5279 },
5280 ));
5281 model.end_last_completion_stream();
5282
5283 cx.run_until_parked();
5284
5285 // Parent model responds to complete second turn
5286 model.send_last_completion_stream_text_chunk("Second response");
5287 model.end_last_completion_stream();
5288
5289 send2.await.unwrap();
5290
5291 // The resumed subagent should have completed normally since the ratio
5292 // didn't transition (it was Warning before and stayed at Warning)
5293 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5294 assert!(
5295 markdown.contains("follow-up task response"),
5296 "resumed subagent should complete normally when already at warning, got:\n{markdown}"
5297 );
5298 // The second tool call should NOT have a context window warning
5299 let second_tool_pos = markdown
5300 .find("follow-up task")
5301 .expect("should find follow-up tool call");
5302 let after_second_tool = &markdown[second_tool_pos..];
5303 assert!(
5304 !after_second_tool.contains("nearing the end of its context window"),
5305 "should NOT contain context window warning for resumed subagent at same level, got:\n{after_second_tool}"
5306 );
5307}
5308
5309#[gpui::test]
5310async fn test_subagent_error_propagation(cx: &mut TestAppContext) {
5311 init_test(cx);
5312 cx.update(|cx| {
5313 LanguageModelRegistry::test(cx);
5314 });
5315 cx.update(|cx| {
5316 cx.update_flags(true, vec!["subagents".to_string()]);
5317 });
5318
5319 let fs = FakeFs::new(cx.executor());
5320 fs.insert_tree(
5321 "/",
5322 json!({
5323 "a": {
5324 "b.md": "Lorem"
5325 }
5326 }),
5327 )
5328 .await;
5329 let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
5330 let thread_store = cx.new(|cx| ThreadStore::new(cx));
5331 let agent = NativeAgent::new(
5332 project.clone(),
5333 thread_store.clone(),
5334 Templates::new(),
5335 None,
5336 fs.clone(),
5337 &mut cx.to_async(),
5338 )
5339 .await
5340 .unwrap();
5341 let connection = Rc::new(NativeAgentConnection(agent.clone()));
5342
5343 let acp_thread = cx
5344 .update(|cx| {
5345 connection
5346 .clone()
5347 .new_session(project.clone(), Path::new(""), cx)
5348 })
5349 .await
5350 .unwrap();
5351 let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
5352 let thread = agent.read_with(cx, |agent, _| {
5353 agent.sessions.get(&session_id).unwrap().thread.clone()
5354 });
5355 let model = Arc::new(FakeLanguageModel::default());
5356
5357 thread.update(cx, |thread, cx| {
5358 thread.set_model(model.clone(), cx);
5359 });
5360 cx.run_until_parked();
5361
5362 // Start the parent turn
5363 let send = acp_thread.update(cx, |thread, cx| thread.send_raw("Prompt", cx));
5364 cx.run_until_parked();
5365 model.send_last_completion_stream_text_chunk("spawning subagent");
5366 let subagent_tool_input = SpawnAgentToolInput {
5367 label: "label".to_string(),
5368 message: "subagent task prompt".to_string(),
5369 session_id: None,
5370 };
5371 let subagent_tool_use = LanguageModelToolUse {
5372 id: "subagent_1".into(),
5373 name: SpawnAgentTool::NAME.into(),
5374 raw_input: serde_json::to_string(&subagent_tool_input).unwrap(),
5375 input: serde_json::to_value(&subagent_tool_input).unwrap(),
5376 is_input_complete: true,
5377 thought_signature: None,
5378 };
5379 model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
5380 subagent_tool_use,
5381 ));
5382 model.end_last_completion_stream();
5383
5384 cx.run_until_parked();
5385
5386 // Verify subagent is running
5387 thread.read_with(cx, |thread, cx| {
5388 assert!(
5389 !thread.running_subagent_ids(cx).is_empty(),
5390 "subagent should be running"
5391 );
5392 });
5393
5394 // The subagent's model returns a non-retryable error
5395 model.send_last_completion_stream_error(LanguageModelCompletionError::PromptTooLarge {
5396 tokens: None,
5397 });
5398
5399 cx.run_until_parked();
5400
5401 // The subagent should no longer be running
5402 thread.read_with(cx, |thread, cx| {
5403 assert!(
5404 thread.running_subagent_ids(cx).is_empty(),
5405 "subagent should not be running after error"
5406 );
5407 });
5408
5409 // The parent model should get a new completion request to respond to the tool error
5410 model.send_last_completion_stream_text_chunk("Response after error");
5411 model.end_last_completion_stream();
5412
5413 send.await.unwrap();
5414
5415 // Verify the parent thread shows the error in the tool call
5416 let markdown = acp_thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
5417 assert!(
5418 markdown.contains("Status: Failed"),
5419 "tool call should have Failed status after model error, got:\n{markdown}"
5420 );
5421}
5422
5423#[gpui::test]
5424async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5425 init_test(cx);
5426
5427 let fs = FakeFs::new(cx.executor());
5428 fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5429 .await;
5430 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5431
5432 cx.update(|cx| {
5433 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5434 settings.tool_permissions.tools.insert(
5435 EditFileTool::NAME.into(),
5436 agent_settings::ToolRules {
5437 default: Some(settings::ToolPermissionMode::Allow),
5438 always_allow: vec![],
5439 always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5440 always_confirm: vec![],
5441 invalid_patterns: vec![],
5442 },
5443 );
5444 agent_settings::AgentSettings::override_global(settings, cx);
5445 });
5446
5447 let context_server_registry =
5448 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5449 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5450 let templates = crate::Templates::new();
5451 let thread = cx.new(|cx| {
5452 crate::Thread::new(
5453 project.clone(),
5454 cx.new(|_cx| prompt_store::ProjectContext::default()),
5455 context_server_registry,
5456 templates.clone(),
5457 None,
5458 cx,
5459 )
5460 });
5461
5462 #[allow(clippy::arc_with_non_send_sync)]
5463 let tool = Arc::new(crate::EditFileTool::new(
5464 project.clone(),
5465 thread.downgrade(),
5466 language_registry,
5467 templates,
5468 ));
5469 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5470
5471 let task = cx.update(|cx| {
5472 tool.run(
5473 ToolInput::resolved(crate::EditFileToolInput {
5474 display_description: "Edit sensitive file".to_string(),
5475 path: "root/sensitive_config.txt".into(),
5476 mode: crate::EditFileMode::Edit,
5477 }),
5478 event_stream,
5479 cx,
5480 )
5481 });
5482
5483 let result = task.await;
5484 assert!(result.is_err(), "expected edit to be blocked");
5485 assert!(
5486 result.unwrap_err().to_string().contains("blocked"),
5487 "error should mention the edit was blocked"
5488 );
5489}
5490
5491#[gpui::test]
5492async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5493 init_test(cx);
5494
5495 let fs = FakeFs::new(cx.executor());
5496 fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5497 .await;
5498 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5499
5500 cx.update(|cx| {
5501 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5502 settings.tool_permissions.tools.insert(
5503 DeletePathTool::NAME.into(),
5504 agent_settings::ToolRules {
5505 default: Some(settings::ToolPermissionMode::Allow),
5506 always_allow: vec![],
5507 always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5508 always_confirm: vec![],
5509 invalid_patterns: vec![],
5510 },
5511 );
5512 agent_settings::AgentSettings::override_global(settings, cx);
5513 });
5514
5515 let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5516
5517 #[allow(clippy::arc_with_non_send_sync)]
5518 let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5519 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5520
5521 let task = cx.update(|cx| {
5522 tool.run(
5523 ToolInput::resolved(crate::DeletePathToolInput {
5524 path: "root/important_data.txt".to_string(),
5525 }),
5526 event_stream,
5527 cx,
5528 )
5529 });
5530
5531 let result = task.await;
5532 assert!(result.is_err(), "expected deletion to be blocked");
5533 assert!(
5534 result.unwrap_err().contains("blocked"),
5535 "error should mention the deletion was blocked"
5536 );
5537}
5538
5539#[gpui::test]
5540async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5541 init_test(cx);
5542
5543 let fs = FakeFs::new(cx.executor());
5544 fs.insert_tree(
5545 "/root",
5546 json!({
5547 "safe.txt": "content",
5548 "protected": {}
5549 }),
5550 )
5551 .await;
5552 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5553
5554 cx.update(|cx| {
5555 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5556 settings.tool_permissions.tools.insert(
5557 MovePathTool::NAME.into(),
5558 agent_settings::ToolRules {
5559 default: Some(settings::ToolPermissionMode::Allow),
5560 always_allow: vec![],
5561 always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5562 always_confirm: vec![],
5563 invalid_patterns: vec![],
5564 },
5565 );
5566 agent_settings::AgentSettings::override_global(settings, cx);
5567 });
5568
5569 #[allow(clippy::arc_with_non_send_sync)]
5570 let tool = Arc::new(crate::MovePathTool::new(project));
5571 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5572
5573 let task = cx.update(|cx| {
5574 tool.run(
5575 ToolInput::resolved(crate::MovePathToolInput {
5576 source_path: "root/safe.txt".to_string(),
5577 destination_path: "root/protected/safe.txt".to_string(),
5578 }),
5579 event_stream,
5580 cx,
5581 )
5582 });
5583
5584 let result = task.await;
5585 assert!(
5586 result.is_err(),
5587 "expected move to be blocked due to destination path"
5588 );
5589 assert!(
5590 result.unwrap_err().contains("blocked"),
5591 "error should mention the move was blocked"
5592 );
5593}
5594
5595#[gpui::test]
5596async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5597 init_test(cx);
5598
5599 let fs = FakeFs::new(cx.executor());
5600 fs.insert_tree(
5601 "/root",
5602 json!({
5603 "secret.txt": "secret content",
5604 "public": {}
5605 }),
5606 )
5607 .await;
5608 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5609
5610 cx.update(|cx| {
5611 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5612 settings.tool_permissions.tools.insert(
5613 MovePathTool::NAME.into(),
5614 agent_settings::ToolRules {
5615 default: Some(settings::ToolPermissionMode::Allow),
5616 always_allow: vec![],
5617 always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5618 always_confirm: vec![],
5619 invalid_patterns: vec![],
5620 },
5621 );
5622 agent_settings::AgentSettings::override_global(settings, cx);
5623 });
5624
5625 #[allow(clippy::arc_with_non_send_sync)]
5626 let tool = Arc::new(crate::MovePathTool::new(project));
5627 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5628
5629 let task = cx.update(|cx| {
5630 tool.run(
5631 ToolInput::resolved(crate::MovePathToolInput {
5632 source_path: "root/secret.txt".to_string(),
5633 destination_path: "root/public/not_secret.txt".to_string(),
5634 }),
5635 event_stream,
5636 cx,
5637 )
5638 });
5639
5640 let result = task.await;
5641 assert!(
5642 result.is_err(),
5643 "expected move to be blocked due to source path"
5644 );
5645 assert!(
5646 result.unwrap_err().contains("blocked"),
5647 "error should mention the move was blocked"
5648 );
5649}
5650
5651#[gpui::test]
5652async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5653 init_test(cx);
5654
5655 let fs = FakeFs::new(cx.executor());
5656 fs.insert_tree(
5657 "/root",
5658 json!({
5659 "confidential.txt": "confidential data",
5660 "dest": {}
5661 }),
5662 )
5663 .await;
5664 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5665
5666 cx.update(|cx| {
5667 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5668 settings.tool_permissions.tools.insert(
5669 CopyPathTool::NAME.into(),
5670 agent_settings::ToolRules {
5671 default: Some(settings::ToolPermissionMode::Allow),
5672 always_allow: vec![],
5673 always_deny: vec![
5674 agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5675 ],
5676 always_confirm: vec![],
5677 invalid_patterns: vec![],
5678 },
5679 );
5680 agent_settings::AgentSettings::override_global(settings, cx);
5681 });
5682
5683 #[allow(clippy::arc_with_non_send_sync)]
5684 let tool = Arc::new(crate::CopyPathTool::new(project));
5685 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5686
5687 let task = cx.update(|cx| {
5688 tool.run(
5689 ToolInput::resolved(crate::CopyPathToolInput {
5690 source_path: "root/confidential.txt".to_string(),
5691 destination_path: "root/dest/copy.txt".to_string(),
5692 }),
5693 event_stream,
5694 cx,
5695 )
5696 });
5697
5698 let result = task.await;
5699 assert!(result.is_err(), "expected copy to be blocked");
5700 assert!(
5701 result.unwrap_err().contains("blocked"),
5702 "error should mention the copy was blocked"
5703 );
5704}
5705
5706#[gpui::test]
5707async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5708 init_test(cx);
5709
5710 let fs = FakeFs::new(cx.executor());
5711 fs.insert_tree(
5712 "/root",
5713 json!({
5714 "normal.txt": "normal content",
5715 "readonly": {
5716 "config.txt": "readonly content"
5717 }
5718 }),
5719 )
5720 .await;
5721 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5722
5723 cx.update(|cx| {
5724 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5725 settings.tool_permissions.tools.insert(
5726 SaveFileTool::NAME.into(),
5727 agent_settings::ToolRules {
5728 default: Some(settings::ToolPermissionMode::Allow),
5729 always_allow: vec![],
5730 always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5731 always_confirm: vec![],
5732 invalid_patterns: vec![],
5733 },
5734 );
5735 agent_settings::AgentSettings::override_global(settings, cx);
5736 });
5737
5738 #[allow(clippy::arc_with_non_send_sync)]
5739 let tool = Arc::new(crate::SaveFileTool::new(project));
5740 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5741
5742 let task = cx.update(|cx| {
5743 tool.run(
5744 ToolInput::resolved(crate::SaveFileToolInput {
5745 paths: vec![
5746 std::path::PathBuf::from("root/normal.txt"),
5747 std::path::PathBuf::from("root/readonly/config.txt"),
5748 ],
5749 }),
5750 event_stream,
5751 cx,
5752 )
5753 });
5754
5755 let result = task.await;
5756 assert!(
5757 result.is_err(),
5758 "expected save to be blocked due to denied path"
5759 );
5760 assert!(
5761 result.unwrap_err().contains("blocked"),
5762 "error should mention the save was blocked"
5763 );
5764}
5765
5766#[gpui::test]
5767async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5768 init_test(cx);
5769
5770 let fs = FakeFs::new(cx.executor());
5771 fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5772 .await;
5773 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5774
5775 cx.update(|cx| {
5776 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5777 settings.tool_permissions.tools.insert(
5778 SaveFileTool::NAME.into(),
5779 agent_settings::ToolRules {
5780 default: Some(settings::ToolPermissionMode::Allow),
5781 always_allow: vec![],
5782 always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5783 always_confirm: vec![],
5784 invalid_patterns: vec![],
5785 },
5786 );
5787 agent_settings::AgentSettings::override_global(settings, cx);
5788 });
5789
5790 #[allow(clippy::arc_with_non_send_sync)]
5791 let tool = Arc::new(crate::SaveFileTool::new(project));
5792 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5793
5794 let task = cx.update(|cx| {
5795 tool.run(
5796 ToolInput::resolved(crate::SaveFileToolInput {
5797 paths: vec![std::path::PathBuf::from("root/config.secret")],
5798 }),
5799 event_stream,
5800 cx,
5801 )
5802 });
5803
5804 let result = task.await;
5805 assert!(result.is_err(), "expected save to be blocked");
5806 assert!(
5807 result.unwrap_err().contains("blocked"),
5808 "error should mention the save was blocked"
5809 );
5810}
5811
5812#[gpui::test]
5813async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5814 init_test(cx);
5815
5816 cx.update(|cx| {
5817 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5818 settings.tool_permissions.tools.insert(
5819 WebSearchTool::NAME.into(),
5820 agent_settings::ToolRules {
5821 default: Some(settings::ToolPermissionMode::Allow),
5822 always_allow: vec![],
5823 always_deny: vec![
5824 agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5825 ],
5826 always_confirm: vec![],
5827 invalid_patterns: vec![],
5828 },
5829 );
5830 agent_settings::AgentSettings::override_global(settings, cx);
5831 });
5832
5833 #[allow(clippy::arc_with_non_send_sync)]
5834 let tool = Arc::new(crate::WebSearchTool);
5835 let (event_stream, _rx) = crate::ToolCallEventStream::test();
5836
5837 let input: crate::WebSearchToolInput =
5838 serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5839
5840 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
5841
5842 let result = task.await;
5843 assert!(result.is_err(), "expected search to be blocked");
5844 match result.unwrap_err() {
5845 crate::WebSearchToolOutput::Error { error } => {
5846 assert!(
5847 error.contains("blocked"),
5848 "error should mention the search was blocked"
5849 );
5850 }
5851 other => panic!("expected Error variant, got: {other:?}"),
5852 }
5853}
5854
5855#[gpui::test]
5856async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5857 init_test(cx);
5858
5859 let fs = FakeFs::new(cx.executor());
5860 fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5861 .await;
5862 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5863
5864 cx.update(|cx| {
5865 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5866 settings.tool_permissions.tools.insert(
5867 EditFileTool::NAME.into(),
5868 agent_settings::ToolRules {
5869 default: Some(settings::ToolPermissionMode::Confirm),
5870 always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5871 always_deny: vec![],
5872 always_confirm: vec![],
5873 invalid_patterns: vec![],
5874 },
5875 );
5876 agent_settings::AgentSettings::override_global(settings, cx);
5877 });
5878
5879 let context_server_registry =
5880 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5881 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5882 let templates = crate::Templates::new();
5883 let thread = cx.new(|cx| {
5884 crate::Thread::new(
5885 project.clone(),
5886 cx.new(|_cx| prompt_store::ProjectContext::default()),
5887 context_server_registry,
5888 templates.clone(),
5889 None,
5890 cx,
5891 )
5892 });
5893
5894 #[allow(clippy::arc_with_non_send_sync)]
5895 let tool = Arc::new(crate::EditFileTool::new(
5896 project,
5897 thread.downgrade(),
5898 language_registry,
5899 templates,
5900 ));
5901 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5902
5903 let _task = cx.update(|cx| {
5904 tool.run(
5905 ToolInput::resolved(crate::EditFileToolInput {
5906 display_description: "Edit README".to_string(),
5907 path: "root/README.md".into(),
5908 mode: crate::EditFileMode::Edit,
5909 }),
5910 event_stream,
5911 cx,
5912 )
5913 });
5914
5915 cx.run_until_parked();
5916
5917 let event = rx.try_next();
5918 assert!(
5919 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5920 "expected no authorization request for allowed .md file"
5921 );
5922}
5923
5924#[gpui::test]
5925async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut TestAppContext) {
5926 init_test(cx);
5927
5928 let fs = FakeFs::new(cx.executor());
5929 fs.insert_tree(
5930 "/root",
5931 json!({
5932 ".zed": {
5933 "settings.json": "{}"
5934 },
5935 "README.md": "# Hello"
5936 }),
5937 )
5938 .await;
5939 let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5940
5941 cx.update(|cx| {
5942 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5943 settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
5944 agent_settings::AgentSettings::override_global(settings, cx);
5945 });
5946
5947 let context_server_registry =
5948 cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5949 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5950 let templates = crate::Templates::new();
5951 let thread = cx.new(|cx| {
5952 crate::Thread::new(
5953 project.clone(),
5954 cx.new(|_cx| prompt_store::ProjectContext::default()),
5955 context_server_registry,
5956 templates.clone(),
5957 None,
5958 cx,
5959 )
5960 });
5961
5962 #[allow(clippy::arc_with_non_send_sync)]
5963 let tool = Arc::new(crate::EditFileTool::new(
5964 project,
5965 thread.downgrade(),
5966 language_registry,
5967 templates,
5968 ));
5969
5970 // Editing a file inside .zed/ should still prompt even with global default: allow,
5971 // because local settings paths are sensitive and require confirmation regardless.
5972 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5973 let _task = cx.update(|cx| {
5974 tool.run(
5975 ToolInput::resolved(crate::EditFileToolInput {
5976 display_description: "Edit local settings".to_string(),
5977 path: "root/.zed/settings.json".into(),
5978 mode: crate::EditFileMode::Edit,
5979 }),
5980 event_stream,
5981 cx,
5982 )
5983 });
5984
5985 let _update = rx.expect_update_fields().await;
5986 let _auth = rx.expect_authorization().await;
5987}
5988
5989#[gpui::test]
5990async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5991 init_test(cx);
5992
5993 cx.update(|cx| {
5994 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5995 settings.tool_permissions.tools.insert(
5996 FetchTool::NAME.into(),
5997 agent_settings::ToolRules {
5998 default: Some(settings::ToolPermissionMode::Allow),
5999 always_allow: vec![],
6000 always_deny: vec![
6001 agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
6002 ],
6003 always_confirm: vec![],
6004 invalid_patterns: vec![],
6005 },
6006 );
6007 agent_settings::AgentSettings::override_global(settings, cx);
6008 });
6009
6010 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6011
6012 #[allow(clippy::arc_with_non_send_sync)]
6013 let tool = Arc::new(crate::FetchTool::new(http_client));
6014 let (event_stream, _rx) = crate::ToolCallEventStream::test();
6015
6016 let input: crate::FetchToolInput =
6017 serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
6018
6019 let task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6020
6021 let result = task.await;
6022 assert!(result.is_err(), "expected fetch to be blocked");
6023 assert!(
6024 result.unwrap_err().contains("blocked"),
6025 "error should mention the fetch was blocked"
6026 );
6027}
6028
6029#[gpui::test]
6030async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
6031 init_test(cx);
6032
6033 cx.update(|cx| {
6034 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
6035 settings.tool_permissions.tools.insert(
6036 FetchTool::NAME.into(),
6037 agent_settings::ToolRules {
6038 default: Some(settings::ToolPermissionMode::Confirm),
6039 always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
6040 always_deny: vec![],
6041 always_confirm: vec![],
6042 invalid_patterns: vec![],
6043 },
6044 );
6045 agent_settings::AgentSettings::override_global(settings, cx);
6046 });
6047
6048 let http_client = gpui::http_client::FakeHttpClient::with_200_response();
6049
6050 #[allow(clippy::arc_with_non_send_sync)]
6051 let tool = Arc::new(crate::FetchTool::new(http_client));
6052 let (event_stream, mut rx) = crate::ToolCallEventStream::test();
6053
6054 let input: crate::FetchToolInput =
6055 serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
6056
6057 let _task = cx.update(|cx| tool.run(ToolInput::resolved(input), event_stream, cx));
6058
6059 cx.run_until_parked();
6060
6061 let event = rx.try_next();
6062 assert!(
6063 !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
6064 "expected no authorization request for allowed docs.rs URL"
6065 );
6066}
6067
6068#[gpui::test]
6069async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
6070 init_test(cx);
6071 always_allow_tools(cx);
6072
6073 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
6074 let fake_model = model.as_fake();
6075
6076 // Add a tool so we can simulate tool calls
6077 thread.update(cx, |thread, _cx| {
6078 thread.add_tool(EchoTool);
6079 });
6080
6081 // Start a turn by sending a message
6082 let mut events = thread
6083 .update(cx, |thread, cx| {
6084 thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
6085 })
6086 .unwrap();
6087 cx.run_until_parked();
6088
6089 // Simulate the model making a tool call
6090 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
6091 LanguageModelToolUse {
6092 id: "tool_1".into(),
6093 name: "echo".into(),
6094 raw_input: r#"{"text": "hello"}"#.into(),
6095 input: json!({"text": "hello"}),
6096 is_input_complete: true,
6097 thought_signature: None,
6098 },
6099 ));
6100 fake_model
6101 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
6102
6103 // Signal that a message is queued before ending the stream
6104 thread.update(cx, |thread, _cx| {
6105 thread.set_has_queued_message(true);
6106 });
6107
6108 // Now end the stream - tool will run, and the boundary check should see the queue
6109 fake_model.end_last_completion_stream();
6110
6111 // Collect all events until the turn stops
6112 let all_events = collect_events_until_stop(&mut events, cx).await;
6113
6114 // Verify we received the tool call event
6115 let tool_call_ids: Vec<_> = all_events
6116 .iter()
6117 .filter_map(|e| match e {
6118 Ok(ThreadEvent::ToolCall(tc)) => Some(tc.tool_call_id.to_string()),
6119 _ => None,
6120 })
6121 .collect();
6122 assert_eq!(
6123 tool_call_ids,
6124 vec!["tool_1"],
6125 "Should have received a tool call event for our echo tool"
6126 );
6127
6128 // The turn should have stopped with EndTurn
6129 let stop_reasons = stop_events(all_events);
6130 assert_eq!(
6131 stop_reasons,
6132 vec![acp::StopReason::EndTurn],
6133 "Turn should have ended after tool completion due to queued message"
6134 );
6135
6136 // Verify the queued message flag is still set
6137 thread.update(cx, |thread, _cx| {
6138 assert!(
6139 thread.has_queued_message(),
6140 "Should still have queued message flag set"
6141 );
6142 });
6143
6144 // Thread should be idle now
6145 thread.update(cx, |thread, _cx| {
6146 assert!(
6147 thread.is_turn_complete(),
6148 "Thread should not be running after turn ends"
6149 );
6150 });
6151}